From f70e20602db8b9dee999418f141cac91e7acd804 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Thu, 19 Sep 2024 16:40:45 +0200 Subject: [PATCH] Use enum to configure sync write mode --- cached_proc_macro/src/cached.rs | 71 +++++++++++++++------------------ cached_proc_macro/src/once.rs | 20 ++++++---- tests/cached.rs | 12 +++--- 3 files changed, 52 insertions(+), 51 deletions(-) diff --git a/cached_proc_macro/src/cached.rs b/cached_proc_macro/src/cached.rs index 21fbe2b..b4025d5 100644 --- a/cached_proc_macro/src/cached.rs +++ b/cached_proc_macro/src/cached.rs @@ -3,9 +3,18 @@ use darling::ast::NestedMeta; use darling::FromMeta; use proc_macro::TokenStream; use quote::quote; +use std::cmp::PartialEq; use syn::spanned::Spanned; use syn::{parse_macro_input, parse_str, Block, Ident, ItemFn, ReturnType, Type}; +#[derive(Debug, Default, FromMeta, Eq, PartialEq)] +enum SyncWriteMode { + #[default] + Disabled, + Default, + ByKey, +} + #[derive(FromMeta)] struct MacroArgs { #[darling(default)] @@ -27,9 +36,7 @@ struct MacroArgs { #[darling(default)] option: bool, #[darling(default)] - sync_writes: bool, - #[darling(default)] - sync_writes_by_key: bool, + sync_writes: SyncWriteMode, #[darling(default)] with_cached_flag: bool, #[darling(default)] @@ -192,16 +199,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { _ => panic!("the result and option attributes are mutually exclusive"), }; - if args.result_fallback && args.sync_writes { - panic!("the result_fallback and sync_writes attributes are mutually exclusive"); - } - - if args.result_fallback && args.sync_writes_by_key { - panic!("the result_fallback and sync_writes_by_key attributes are mutually exclusive"); - } - - if args.sync_writes && args.sync_writes_by_key { - panic!("the sync_writes and sync_writes_by_key attributes are mutually exclusive"); + if args.result_fallback && args.sync_writes != SyncWriteMode::Disabled { + panic!("result_fallback and sync_writes are mutually exclusive"); } let set_cache_and_return = quote! { @@ -216,8 +215,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let function_call; let ty; if asyncness.is_some() { - lock = if args.sync_writes_by_key { - quote! { + lock = match args.sync_writes { + SyncWriteMode::ByKey => quote! { let mut locks = #cache_ident.lock().await; let lock = locks .entry(key.clone()) @@ -225,11 +224,10 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { .clone(); drop(locks); let mut cache = lock.lock().await; - } - } else { - quote! { + }, + _ => quote! { let mut cache = #cache_ident.lock().await; - } + }, }; function_no_cache = quote! { @@ -240,27 +238,25 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*).await; }; - ty = if args.sync_writes_by_key { - quote! { + ty = match args.sync_writes { + SyncWriteMode::ByKey => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new())); - } - } else { - quote! { + }, + _ => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create)); - } + }, }; } else { - lock = if args.sync_writes_by_key { - quote! { + lock = match args.sync_writes { + SyncWriteMode::ByKey => quote! { let mut locks = #cache_ident.lock().unwrap(); let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone(); drop(locks); let mut cache = lock.lock().unwrap(); - } - } else { - quote! { + }, + _ => quote! { let mut cache = #cache_ident.lock().unwrap(); - } + }, }; function_no_cache = quote! { @@ -271,14 +267,13 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*); }; - ty = if args.sync_writes_by_key { - quote! { + ty = match args.sync_writes { + SyncWriteMode::ByKey => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new())); - } - } else { - quote! { + }, + _ => quote! { #visibility static #cache_ident: ::cached::once_cell::sync::Lazy> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create)); - } + }, } } @@ -290,7 +285,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { #set_cache_and_return }; - let do_set_return_block = if args.sync_writes_by_key || args.sync_writes { + let do_set_return_block = if args.sync_writes != SyncWriteMode::Disabled { quote! { #lock if let Some(result) = cache.cache_get(&key) { diff --git a/cached_proc_macro/src/once.rs b/cached_proc_macro/src/once.rs index 70d5617..fd2ac07 100644 --- a/cached_proc_macro/src/once.rs +++ b/cached_proc_macro/src/once.rs @@ -6,6 +6,13 @@ use quote::quote; use syn::spanned::Spanned; use syn::{parse_macro_input, Ident, ItemFn, ReturnType}; +#[derive(Debug, Default, FromMeta)] +enum SyncWriteMode { + #[default] + Disabled, + Default, +} + #[derive(FromMeta)] struct OnceMacroArgs { #[darling(default)] @@ -13,7 +20,7 @@ struct OnceMacroArgs { #[darling(default)] time: Option, #[darling(default)] - sync_writes: bool, + sync_writes: SyncWriteMode, #[darling(default)] result: bool, #[darling(default)] @@ -220,8 +227,8 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } }; - let do_set_return_block = if args.sync_writes { - quote! { + let do_set_return_block = match args.sync_writes { + SyncWriteMode::Default => quote! { #r_lock_return_cache_block #w_lock if let Some(result) = &*cached { @@ -229,14 +236,13 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } #function_call #set_cache_and_return - } - } else { - quote! { + }, + SyncWriteMode::Disabled => quote! { #r_lock_return_cache_block #function_call #w_lock #set_cache_and_return - } + }, }; let signature_no_muts = get_mut_signature(signature); diff --git a/tests/cached.rs b/tests/cached.rs index a2a26bc..0d33e04 100644 --- a/tests/cached.rs +++ b/tests/cached.rs @@ -848,7 +848,7 @@ async fn test_only_cached_option_once_per_second_a() { /// to return the cached result of the one call instead of all /// concurrently un-cached tasks executing and writing concurrently. #[cfg(feature = "async")] -#[once(time = 2, sync_writes = true)] +#[once(time = 2, sync_writes = "default")] async fn only_cached_once_per_second_sync_writes(s: String) -> Vec { vec![s] } @@ -862,7 +862,7 @@ async fn test_only_cached_once_per_second_sync_writes() { assert_eq!(a.await.unwrap(), b.await.unwrap()); } -#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")] +#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")] fn cached_sync_writes(s: String) -> Vec { vec![s] } @@ -881,7 +881,7 @@ fn test_cached_sync_writes() { } #[cfg(feature = "async")] -#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")] +#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")] async fn cached_sync_writes_a(s: String) -> Vec { vec![s] } @@ -898,7 +898,7 @@ async fn test_cached_sync_writes_a() { assert_eq!(a, c.await.unwrap()); } -#[cached(time = 2, sync_writes_by_key = true, key = "u32", convert = "{ 1 }")] +#[cached(time = 2, sync_writes = "by_key", key = "u32", convert = "{ 1 }")] fn cached_sync_writes_by_key(s: String) -> Vec { sleep(Duration::new(1, 0)); vec![s] @@ -919,7 +919,7 @@ fn test_cached_sync_writes_by_key() { #[cfg(feature = "async")] #[cached( time = 5, - sync_writes_by_key = true, + sync_writes = "by_key", key = "String", convert = r#"{ format!("{}", s) }"# )] @@ -942,7 +942,7 @@ async fn test_cached_sync_writes_by_key_a() { } #[cfg(feature = "async")] -#[once(sync_writes = true)] +#[once(sync_writes = "default")] async fn once_sync_writes_a(s: &tokio::sync::Mutex) -> String { let mut guard = s.lock().await; let results: String = (*guard).clone().to_string();