diff --git a/Cargo.toml b/Cargo.toml index 2ac935b..b2b4f88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "gluer" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Nils Wrenger "] -description = "A wrapper for rust frameworks which addresses the persistent issue of redundant type definitions between the frontend and backend" +description = "A wrapper for rust frameworks which addresses the persistent issue of redundant type and function definitions between the frontend and backend" keywords = ["parser", "api", "macro"] categories = ["accessibility", "web-programming", "api-bindings"] rust-version = "1.64.0" @@ -26,3 +26,4 @@ serde_yaml = "0.9.34" [dev-dependencies] axum = "0.7.5" tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] } +serde = { version = "1.0", features = ["derive"] } diff --git a/README.md b/README.md index 603678d..f609d8e 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![crates.io](https://img.shields.io/crates/d/gluer.svg)](https://crates.io/crates/gluer) [![docs.rs](https://docs.rs/gluer/badge.svg)](https://docs.rs/gluer) -A wrapper for rust frameworks which addresses the persistent issue of redundant type definitions between the frontend and backend. At present, it exclusively supports the `axum` framework. +A wrapper for rust frameworks which addresses the persistent issue of redundant type and function definitions between the frontend and backend. At present, it exclusively supports the `axum` framework. ## Installation @@ -17,7 +17,12 @@ light_magic = "0.1.0" ## Disclaimer -Please be informed that this crate is in a very early state and is expected to work in not every case. Open a Issue if you encounter one! +Please be informed that this crate is in a very early state and is expected to work in not every case. Open a Issue if you encounter one! What works is: + +- Defining the routing and api generation as outlined in [How to use](#how-to-use) +- Inferring the input and output types of functions (but only `Json<...>` for inputs) +- Converting them to ts types +- Generating the ts file with the functions and data types ## How to use @@ -27,53 +32,54 @@ Firstly you have to use the `add_route!` macro when adding api important routes use axum::{ routing::{get, post}, Router, + Json, }; use gluer::add_route; -async fn root() -> &'static str { - "Hello, World!" +async fn root() -> Json<&'static str> { + "Hello, World!".into() } -let mut app = Router::new(); +let mut app: Router<()> = Router::new(); // Not api important, so adding without macro -app = app.router(get(root)); +app = app.route("/", get(root)); -// You currently cannot use inline functions, just path to the functions inside the methods (meaning `path(|| async &'static "Hello World!")` won't work!) +// You cannot use inline functions because of rust limitations of inferring types in macros add_route!(app, "/", post(root)); add_route!(app, "/user", post(root).delete(root)); ``` -Then you only have to use the `gen_spec!` which generates after specifying title, version and path the openapi doc on comptime: +Then you only have to use the `gen_spec!` macro which generates after specifying the path the api on comptime: ```rust use gluer::gen_spec; -gen_spec!("test", "0.1.0", "tests/test.json"); +gen_spec!("tests/api.ts"); ``` -Then use a library like `openapi-typescript` to generate your fitting `TS Client` code! - ### Complete Example -```rust -use axum::{ - routing::{get, post}, - Router, -}; +```rust,no_run +use axum::{routing::post, Json, Router}; use gluer::{add_route, gen_spec}; -async fn root() -> &'static str { - "Hello, World!" +#[derive(serde::Deserialize)] +struct Hello { + _name: String, +} + +async fn root(Json(_hello): Json) -> Json<&'static str> { + "Hello World!".into() } #[tokio::main] async fn main() { let mut app = Router::new(); - add_route!(app, "/", get(root).post(root)); + add_route!(app, "/", post(root)); - gen_spec!("test", "0.1.0", "tests/test.yaml"); + gen_spec!("tests/api.ts"); let listener = tokio::net::TcpListener::bind("127.0.0.1:8080") .await diff --git a/src/extractors.rs b/src/extractors.rs new file mode 100644 index 0000000..eb8d23a --- /dev/null +++ b/src/extractors.rs @@ -0,0 +1,120 @@ +use std::{collections::HashMap, env::current_dir}; + +use quote::ToTokens; + +pub(crate) fn extract_function( + fn_name: &str, + file_path: std::path::PathBuf, +) -> syn::Result<(Vec, String)> { + let source = std::fs::read_to_string(file_path) + .map_err(|e| syn::Error::new(proc_macro2::Span::mixed_site(), e.to_string()))?; + let syntax = syn::parse_file(&source)?; + + let mut params_map: HashMap> = HashMap::new(); + let mut responses_map: HashMap = HashMap::new(); + + for item in syntax.items { + if let syn::Item::Fn(syn::ItemFn { sig, .. }) = item { + let fn_name = sig.ident.to_string(); + let params: Vec = sig.inputs.iter().cloned().collect(); + params_map.insert(fn_name.clone(), params); + + let ty: String = match sig.output { + syn::ReturnType::Default => "()".to_string(), + syn::ReturnType::Type(_, ty) => ty.into_token_stream().to_string(), + }; + + responses_map.insert(fn_name, ty); + } + } + + let params = params_map.get(fn_name).cloned().ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "Function parameters not found", + ) + })?; + + let responses = responses_map.get(fn_name).cloned().ok_or_else(|| { + syn::Error::new( + proc_macro2::Span::call_site(), + "Function responses not found", + ) + })?; + + Ok((params, responses)) +} + +pub(crate) fn extract_struct( + struct_name: &str, + file_path: std::path::PathBuf, +) -> syn::Result> { + let source = std::fs::read_to_string(&file_path) + .map_err(|e| syn::Error::new(proc_macro2::Span::mixed_site(), e.to_string()))?; + let syntax = syn::parse_file(&source)?; + + for item in syntax.items { + if let syn::Item::Struct(syn::ItemStruct { ident, fields, .. }) = item { + let name = ident.to_string().trim().to_string(); + + if name == struct_name { + let mut field_vec = Vec::new(); + + if let syn::Fields::Named(fields) = fields { + for field in fields.named { + let field_name = field.ident.unwrap().to_string(); + let field_type = field.ty.into_token_stream().to_string(); + field_vec.push((field_name, field_type)); + } + } + + return Ok(field_vec); + } + } + } + + Err(syn::Error::new( + proc_macro2::Span::call_site(), + "Struct definition not found in ".to_string() + file_path.to_string_lossy().as_ref(), + )) +} + +pub(crate) fn resolve_path(segments: Vec) -> syn::Result { + let current_dir = current_dir().map_err(|_| { + syn::Error::new( + proc_macro2::Span::call_site(), + "Failed to get current directory", + ) + })?; + + if segments.len() == 1 { + // Function is in the same file, check if it's in main.rs or lib.rs (for tests) + let main_path = current_dir.join("src/main.rs"); + let lib_path = current_dir.join("src/lib.rs"); + if main_path.exists() { + Ok(main_path) + } else if lib_path.exists() { + Ok(current_dir.join("tests/main.rs")) + } else { + Err(syn::Error::new( + proc_macro2::Span::call_site(), + "Neither main.rs nor lib.rs found", + ))? + } + } else { + // Function is in a different module + let module_path = &segments[0]; + let file_path_mod = current_dir.join(format!("src/{}/mod.rs", module_path)); + let file_path_alt = current_dir.join(format!("src/{}.rs", module_path)); + if file_path_mod.exists() { + Ok(file_path_mod) + } else if file_path_alt.exists() { + Ok(file_path_alt) + } else { + Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!("Module file not found for {}", module_path), + ))? + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 69ee24f..5261a6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,20 +1,13 @@ #![doc = include_str!("../README.md")] +mod extractors; + +use crate::extractors::{extract_function, extract_struct, resolve_path}; use lazy_static::lazy_static; -use okapi::openapi3::{ - MediaType, Object, OpenApi, Operation, Parameter, ParameterValue, RefOr, Response, Responses, - SchemaObject, -}; use proc_macro as pc; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use std::{ - collections::{BTreeMap, HashMap}, - env::current_dir, - fmt, - io::Write, - sync::Mutex, -}; +use std::{fmt, io::Write, sync::Mutex}; use syn::{parenthesized, parse::Parse, spanned::Spanned}; fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error { @@ -26,6 +19,7 @@ lazy_static! { } /// Adds a route to the router. Use for each api endpoint you want to expose to the frontend. +/// `Inline Functions` are not supported because of rust limitations of inferring types in macros. #[proc_macro] pub fn add_route(input: pc::TokenStream) -> pc::TokenStream { match add_route_inner(input.into()) { @@ -46,7 +40,7 @@ fn add_route_inner(input: TokenStream) -> syn::Result { method, r#fn, params, - responses, + response, } in &handler { let fn_name = r#fn.segments.last().unwrap().ident.to_string(); @@ -63,7 +57,7 @@ fn add_route_inner(input: TokenStream) -> syn::Result { method: method.to_string(), fn_name, params, - responses: responses.clone(), + response: response.clone(), }); } @@ -99,7 +93,7 @@ struct MethodCall { method: syn::Ident, r#fn: syn::Path, params: Vec, - responses: HashMap, + response: String, } impl Parse for MethodCall { @@ -109,61 +103,18 @@ impl Parse for MethodCall { parenthesized!(content in input); let r#fn: syn::Path = content.parse()?; - // Retrieve the function name as a string let fn_name = r#fn.segments.last().unwrap().ident.to_string(); + let segments = r#fn.segments.iter().map(|s| s.ident.to_string()).collect(); - // Determine the file to parse based on the path segments - let segments: Vec<_> = r#fn - .segments - .iter() - .map(|seg| seg.ident.to_string()) - .collect(); - let current_dir = current_dir().map_err(|_| { - syn::Error::new( - proc_macro2::Span::call_site(), - "Failed to get current directory", - ) - })?; - - let file_path = if segments.len() == 1 { - // Function is in the same file, check if it's in main.rs or lib.rs (for tests) - let main_path = current_dir.join("src/main.rs"); - let lib_path = current_dir.join("src/lib.rs"); - if main_path.exists() { - main_path - } else if lib_path.exists() { - current_dir.join("tests/main.rs") - } else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - "Neither main.rs nor lib.rs found", - )); - } - } else { - // Function is in a different module - let module_path = &segments[0]; - let file_path_mod = current_dir.join(format!("src/{}/mod.rs", module_path)); - let file_path_alt = current_dir.join(format!("src/{}.rs", module_path)); - if file_path_mod.exists() { - file_path_mod - } else if file_path_alt.exists() { - file_path_alt - } else { - return Err(syn::Error::new( - proc_macro2::Span::call_site(), - format!("Module file not found for {}", module_path), - )); - } - }; + let file_path = resolve_path(segments)?; - let (params, responses) = - extract_function_params_and_responses_from_file(&fn_name, file_path.to_str().unwrap())?; + let (params, response) = extract_function(&fn_name, file_path)?; Ok(MethodCall { method, r#fn, params, - responses, + response, }) } } @@ -178,62 +129,16 @@ impl ToTokens for MethodCall { }); } } - -fn extract_function_params_and_responses_from_file( - fn_name: &str, - file_path: &str, -) -> syn::Result<(Vec, HashMap)> { - let source = std::fs::read_to_string(file_path) - .map_err(|e| syn::Error::new(proc_macro2::Span::mixed_site(), e.to_string()))?; - let syntax = syn::parse_file(&source)?; - - let mut params_map: HashMap> = HashMap::new(); - let mut responses_map: HashMap> = HashMap::new(); - - for item in syntax.items { - if let syn::Item::Fn(syn::ItemFn { sig, .. }) = item { - let fn_name = sig.ident.to_string(); - let params: Vec = sig.inputs.iter().cloned().collect(); - params_map.insert(fn_name.clone(), params); - - let mut responses = HashMap::new(); - let ty: String = match sig.output { - syn::ReturnType::Default => "()".to_string(), - syn::ReturnType::Type(_, ty) => ty.into_token_stream().to_string(), - }; - - responses.insert("200".to_string(), ty); - responses_map.insert(fn_name, responses); - } - } - - let params = params_map.get(fn_name).cloned().ok_or_else(|| { - syn::Error::new( - proc_macro2::Span::call_site(), - "Function parameters not found", - ) - })?; - - let responses = responses_map.get(fn_name).cloned().ok_or_else(|| { - syn::Error::new( - proc_macro2::Span::call_site(), - "Function responses not found", - ) - })?; - - Ok((params, responses)) -} - struct Route { route: String, method: String, fn_name: String, - // should be syn::Type but that's not thread safe + // Should be syn::Type but that's not thread safe params: Vec, - responses: HashMap, + response: String, } -/// Generates an OpenAPI spec from the routes added with `add_route!`. Specify the title, version of the spec, and path to save the spec to. +/// Generates an api ts file from the routes added with `add_route!`. Specify the path to save the api to. #[proc_macro] pub fn gen_spec(input: pc::TokenStream) -> pc::TokenStream { match gen_spec_inner(input.into()) { @@ -247,130 +152,132 @@ fn gen_spec_inner(input: TokenStream) -> syn::Result { let args = syn::parse2::(input)?; let path = args.path.value(); - let title = args.title.value(); - let version = args.version.value(); let routes = ROUTES.lock().unwrap(); - let mut openapi = OpenApi::new(); - openapi.info.title = title; - openapi.info.version = version; - for route in routes.iter() { - let mut path_item = openapi.paths.get(&route.route).cloned().unwrap_or_default(); - - let operation = Operation { - operation_id: Some(route.fn_name.clone() + "_" + &route.method), - parameters: route - .params - .iter() - .filter_map(|param| { - if param.starts_with("State<") { - None - } else { - Some(RefOr::Object(Parameter { - name: param.clone(), - location: route.method.clone(), - description: None, - required: true, - deprecated: false, - allow_empty_value: false, - value: ParameterValue::Schema { - style: None, - explode: None, - allow_reserved: true, - schema: SchemaObject::default(), - example: None, - examples: None, - }, - extensions: Object::default(), - })) - } - }) - .collect(), - request_body: None, - responses: Responses { - responses: route - .responses - .iter() - .map(|(status, ty)| { - ( - status.clone(), - RefOr::Object(Response { - description: format!("{} response", status), - content: { - let mut content = BTreeMap::new(); - content.insert( - "application/json".to_string(), - MediaType { - schema: Some(SchemaObject::new_ref(ty.to_string())), - example: None, - examples: None, - encoding: BTreeMap::default(), - extensions: BTreeMap::default(), - }, - ); - content - }, - headers: BTreeMap::default(), - links: BTreeMap::default(), - extensions: Object::default(), - }), - ) - }) - .collect(), - ..Default::default() - }, - deprecated: false, - security: None, - servers: None, - extensions: Object::default(), - ..Default::default() - }; + let mut ts_functions = String::new(); + let mut ts_interfaces = String::new(); - match route.method.as_str() { - "get" => path_item.get = Some(operation), - "post" => path_item.post = Some(operation), - "put" => path_item.put = Some(operation), - "delete" => path_item.delete = Some(operation), - _ => { - return Err(s_err( - span, - "Unsupported HTTP method ".to_string() + &route.method, - )) + for route in routes.iter() { + let fn_name = route.method.clone() + "_" + &route.fn_name; + let method = &route.method; + let url = &route.route; + + let mut param_names = vec![]; + let mut param_types = vec![]; + + for param in &route.params { + if param.contains("Json") { + let struct_name = param + .split('<') + .nth(1) + .unwrap() + .split('>') + .next() + .unwrap() + .trim(); + param_names.push("data".to_string()); + param_types.push(struct_name.to_string()); + + let file_path = + resolve_path(struct_name.split("::").map(|f| f.to_string()).collect())?; + + let interface = + generate_ts_interface(struct_name, extract_struct(struct_name, file_path)?); + ts_interfaces.push_str(&interface); + } else { + let param_name = param.split(':').next().unwrap().trim().to_string(); + let param_type = convert_rust_type_to_ts(param.split(':').nth(1).unwrap().trim()); + param_names.push(param_name.clone()); + param_types.push(param_type); } } - openapi.paths.insert(route.route.clone(), path_item); - } + let params_str = param_names + .iter() + .zip(param_types.iter()) + .map(|(name, ty)| format!("{}: {}", name, ty)) + .collect::>() + .join(", "); + let response_type = convert_rust_type_to_ts(&route.response.clone()); + + let body_assignment = if param_names.contains(&"data".to_string()) { + "JSON.stringify(data)" + } else { + "undefined" + }; - let spec_string = serde_yaml::to_string(&openapi) - .map_err(|e| s_err(span, format!("Failed to serialize OpenAPI spec: {}", e)))?; + let function = format!( + r#"export async function {fn_name}({params_str}): Promise<{response_type} | any> {{ + const response = await fetch("{url}", {{ + method: "{method}", + headers: {{ + "Content-Type": "application/json" + }}, + body: {body_assignment} + }}); + return response.json(); +}} + +"#, + fn_name = fn_name, + params_str = params_str, + response_type = response_type, + url = url, + method = method.to_uppercase(), + body_assignment = body_assignment + ); + + ts_functions.push_str(&function); + } let mut file = std::fs::File::create(path) .map_err(|e| s_err(span, format!("Failed to create file: {}", e)))?; - file.write_all(spec_string.as_bytes()) + file.write_all(ts_interfaces.as_bytes()) + .map_err(|e| s_err(span, format!("Failed to write to file: {}", e)))?; + file.write_all(ts_functions.as_bytes()) .map_err(|e| s_err(span, format!("Failed to write to file: {}", e)))?; Ok(quote! {}) } +fn convert_rust_type_to_ts(rust_type: &str) -> String { + let rust_type = rust_type.trim(); + match rust_type { + "str" | "String" => "string".to_string(), + "usize" | "isize" | "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64" | "f32" + | "f64" => "number".to_string(), + "bool" => "boolean".to_string(), + "()" => "void".to_string(), + t if t.starts_with("Vec <") => format!("{}[]", convert_rust_type_to_ts(&t[5..t.len() - 1])), + t if t.starts_with("Option <") => convert_rust_type_to_ts(&t[8..t.len() - 1]), + t if t.starts_with("Result <") => convert_rust_type_to_ts(&t[8..t.len() - 1]), + t if t.starts_with("Json <") => convert_rust_type_to_ts(&t[6..t.len() - 1]), + t if t.starts_with('&') => convert_rust_type_to_ts(&t[1..]), + t if t.starts_with("'static") => convert_rust_type_to_ts(&t[8..]), + t => t.to_string(), + } +} + +fn generate_ts_interface(struct_name: &str, fields: Vec<(String, String)>) -> String { + let mut interface = format!("export interface {} {{\n", struct_name); + + for (field_name, field_type) in fields { + let field_type = convert_rust_type_to_ts(&field_type); + interface.push_str(&format!(" {}: {};\n", field_name, field_type)); + } + + interface.push_str("}\n\n"); + interface +} + struct GenArgs { - title: syn::LitStr, - version: syn::LitStr, path: syn::LitStr, } impl Parse for GenArgs { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let title = input.parse()?; - ::parse(input)?; - let version = input.parse()?; - ::parse(input)?; let path = input.parse()?; - Ok(GenArgs { - title, - version, - path, - }) + Ok(GenArgs { path }) } } diff --git a/tests/api.ts b/tests/api.ts new file mode 100644 index 0000000..b161124 --- /dev/null +++ b/tests/api.ts @@ -0,0 +1,15 @@ +export interface Hello { + _name: string; +} + +export async function post_root(data: Hello): Promise { + const response = await fetch("/", { + method: "POST", + headers: { + "Content-Type": "application/json" + }, + body: JSON.stringify(data) + }); + return response.json(); +} + diff --git a/tests/main.rs b/tests/main.rs index ffc5535..0271739 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1,8 +1,13 @@ -use axum::{routing::get, Router}; +use axum::{routing::post, Json, Router}; use gluer::{add_route, gen_spec}; -async fn root() -> &'static str { - "Hello, World!" +#[derive(serde::Deserialize)] +struct Hello { + _name: String, +} + +async fn root(Json(_hello): Json) -> Json<&'static str> { + "Hello World!".into() } #[tokio::test] @@ -10,9 +15,9 @@ async fn root() -> &'static str { async fn main_test() { let mut app = Router::new(); - add_route!(app, "/", get(root).post(root)); + add_route!(app, "/", post(root)); - gen_spec!("test", "0.1.0", "tests/test.yaml"); + gen_spec!("tests/api.ts"); let listener = tokio::net::TcpListener::bind("127.0.0.1:8080") .await diff --git a/tests/test.yaml b/tests/test.yaml deleted file mode 100644 index e4fe7f3..0000000 --- a/tests/test.yaml +++ /dev/null @@ -1,24 +0,0 @@ -openapi: 3.0.0 -info: - title: test - version: 0.1.0 -paths: - /: - get: - operationId: root_get - responses: - '200': - description: 200 response - content: - application/json: - schema: - $ref: '& ''static str' - post: - operationId: root_post - responses: - '200': - description: 200 response - content: - application/json: - schema: - $ref: '& ''static str'