diff --git a/src/lib.rs b/src/lib.rs index db6b367..23ecc67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,17 +4,18 @@ use once_cell::sync::Lazy; use proc_macro as pc; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use std::{collections::BTreeMap, fmt, io::Write, sync::RwLock}; +use std::{ + collections::BTreeMap, + fmt, + io::Write, + sync::{Arc, RwLock}, +}; use syn::{parenthesized, parse::Parse, spanned::Spanned}; fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error { syn::Error::new(span, msg) } -fn lock_err(span: proc_macro2::Span, e: impl fmt::Display) -> syn::Error { - s_err(span, format!("Failed to acquire lock: {}", e)) -} - fn logic_err(span: proc_macro2::Span) -> syn::Error { s_err( span, @@ -22,12 +23,64 @@ fn logic_err(span: proc_macro2::Span) -> syn::Error { ) } +#[derive(Default)] +struct GlobalState { + routes: RwLock>, + structs: RwLock>>, + functions: RwLock>, +} + +impl GlobalState { + fn instance() -> &'static Arc { + static INSTANCE: Lazy> = Lazy::new(|| Arc::new(GlobalState::default())); + &INSTANCE + } + + fn add_route(route: Route) { + let state = GlobalState::instance(); + let mut routes = state.routes.write().unwrap(); + routes.push(route); + } + + fn add_struct(name: String, fields: Vec) { + let state = GlobalState::instance(); + let mut structs = state.structs.write().unwrap(); + structs.insert(name, fields); + } + + fn add_function(name: String, function: Function) { + let state = GlobalState::instance(); + let mut functions = state.functions.write().unwrap(); + functions.insert(name, function); + } + + fn get_routes() -> Vec { + let state = GlobalState::instance(); + let routes = state.routes.read().unwrap(); + routes.clone() + } + + fn get_struct(name: &str) -> Option> { + let state = GlobalState::instance(); + let structs = state.structs.read().unwrap(); + structs.get(name).cloned() + } + + fn get_function(name: &str) -> Option { + let state = GlobalState::instance(); + let functions = state.functions.read().unwrap(); + functions.get(name).cloned() + } +} + +#[derive(Clone)] struct Route { route: String, method: String, fn_name: String, } +#[derive(Clone)] struct Function { params: BTreeMap, response: String, @@ -39,12 +92,6 @@ struct StructField { ty: String, } -static ROUTES: Lazy>> = Lazy::new(|| RwLock::new(Vec::new())); -static STRUCTS: Lazy>>> = - Lazy::new(|| RwLock::new(BTreeMap::new())); -static FUNCTIONS: Lazy>> = - Lazy::new(|| RwLock::new(BTreeMap::new())); - /// Adds a route to the router. Use for each api endpoint you want to expose to the frontend. /// `Inline Functions` are currently not supported. #[proc_macro] @@ -71,7 +118,7 @@ fn add_route_inner(input: TokenStream) -> syn::Result { .ident .to_string(); - ROUTES.write().map_err(|e| lock_err(span, e))?.push(Route { + GlobalState::add_route(Route { route: route.clone(), method: method.to_string(), fn_name, @@ -147,9 +194,7 @@ fn api_inner(input: TokenStream) -> syn::Result { let args = syn::parse2::(input)?; let path = args.path.value(); - let routes = ROUTES.read().map_err(|e| lock_err(span, e))?; - let functions = FUNCTIONS.read().map_err(|e| lock_err(span, e))?; - let structs = STRUCTS.read().map_err(|e| lock_err(span, e))?; + let routes = GlobalState::get_routes(); let mut ts_functions = BTreeMap::new(); let mut ts_interfaces = BTreeMap::new(); @@ -159,7 +204,7 @@ fn api_inner(input: TokenStream) -> syn::Result { let method = &route.method; let url = &route.route; - let function = functions.get(fn_name).ok_or_else(|| { + let function = GlobalState::get_function(fn_name).ok_or_else(|| { s_err( span, format!( @@ -169,10 +214,9 @@ fn api_inner(input: TokenStream) -> syn::Result { ) })?; - let ty = collect_params(function, &structs, span, &mut ts_interfaces)?; + let ty = collect_params(&function, span, &mut ts_interfaces)?; - let response_type = - collect_response_type(&function.response, &structs, span, &mut ts_interfaces)?; + let response_type = collect_response_type(&function.response, span, &mut ts_interfaces)?; let params_str = if !ty.is_empty() { format!("params: {}", ty) @@ -217,14 +261,13 @@ fn api_inner(input: TokenStream) -> syn::Result { fn collect_params( function: &Function, - structs: &BTreeMap>, span: proc_macro2::Span, ts_interfaces: &mut BTreeMap, ) -> syn::Result { for param in &function.params { if param.1.contains("Json") { let struct_name = extract_struct_name(span, param.1)?; - if let Some(fields) = structs.get(&struct_name).cloned() { + if let Some(fields) = GlobalState::get_struct(&struct_name) { ts_interfaces .entry(struct_name.clone()) .or_insert_with(|| generate_ts_interface(&struct_name.clone(), fields)); @@ -250,18 +293,17 @@ fn collect_params( fn collect_response_type( response: &str, - structs: &BTreeMap>, span: proc_macro2::Span, ts_interfaces: &mut BTreeMap, ) -> syn::Result { - let response = response.replace(" ", ""); + let response = response.replace(' ', ""); if let Some(response_type) = convert_rust_type_to_ts(&response) { return Ok(response_type); } if response.contains("Json") { let struct_name = extract_struct_name(span, &response)?; - if let Some(fields) = structs.get(&struct_name).cloned() { + if let Some(fields) = GlobalState::get_struct(&struct_name) { ts_interfaces .entry(struct_name.clone()) .or_insert_with(|| generate_ts_interface(&struct_name, fields)); @@ -378,20 +420,17 @@ fn cached_inner(args: TokenStream, input: TokenStream) -> syn::Result { - STRUCTS - .write() - .map_err(|e| lock_err(span, e))? - .insert(ident.to_string(), { - let mut field_vec = Vec::new(); - - for field in fields { - let ident = field.ident.ok_or_else(|| logic_err(span))?.to_string(); - let ty = field.ty.into_token_stream().to_string(); - field_vec.push(StructField { ident, ty }); - } - - field_vec - }); + GlobalState::add_struct(ident.to_string(), { + let mut field_vec = Vec::new(); + + for field in fields { + let ident = field.ident.ok_or_else(|| logic_err(span))?.to_string(); + let ty = field.ty.into_token_stream().to_string(); + field_vec.push(StructField { ident, ty }); + } + + field_vec + }); } syn::Item::Fn(item_fn) => { let fn_name = item_fn.sig.ident.to_string(); @@ -401,8 +440,8 @@ fn cached_inner(args: TokenStream, input: TokenStream) -> syn::Result "()".to_string(), }; - FUNCTIONS.write().map_err(|e| lock_err(span, e))?.insert( - fn_name.clone(), + GlobalState::add_function( + fn_name, Function { params: { let mut map = BTreeMap::new(); @@ -411,20 +450,20 @@ fn cached_inner(args: TokenStream, input: TokenStream) -> syn::Result { if pat.to_token_stream().to_string() == "Json" { let struct_path = ty.to_token_stream().to_string(); - let struct_name = struct_path.split("::").last().ok_or_else(|| logic_err(span))?.trim(); - let fields = STRUCTS - .read().map_err(|e| lock_err(span, e))? - .get(&struct_name.to_string()) - .ok_or_else(|| { - s_err( - span, - format!( - "Struct '{}' not found in the cache, mind adding it with #[cached]", - struct_name - ), - ) - })? - .clone(); + let struct_name = struct_path + .split("::") + .last() + .ok_or_else(|| logic_err(span))? + .trim(); + let fields = GlobalState::get_struct(struct_name).ok_or_else(|| { + s_err( + span, + format!( + "Struct '{}' not found in the cache, mind adding it with #[cached]", + struct_name + ), + ) + })?; for StructField { ident, ty } in fields { map.insert(ident, ty);