Skip to content

Commit

Permalink
Use enum to configure sync write mode
Browse files Browse the repository at this point in the history
  • Loading branch information
raimannma committed Nov 6, 2024
1 parent b352cdd commit f70e206
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 51 deletions.
71 changes: 33 additions & 38 deletions cached_proc_macro/src/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@ use darling::ast::NestedMeta;
use darling::FromMeta;
use proc_macro::TokenStream;
use quote::quote;
use std::cmp::PartialEq;
use syn::spanned::Spanned;
use syn::{parse_macro_input, parse_str, Block, Ident, ItemFn, ReturnType, Type};

#[derive(Debug, Default, FromMeta, Eq, PartialEq)]
enum SyncWriteMode {
#[default]
Disabled,
Default,
ByKey,
}

#[derive(FromMeta)]
struct MacroArgs {
#[darling(default)]
Expand All @@ -27,9 +36,7 @@ struct MacroArgs {
#[darling(default)]
option: bool,
#[darling(default)]
sync_writes: bool,
#[darling(default)]
sync_writes_by_key: bool,
sync_writes: SyncWriteMode,
#[darling(default)]
with_cached_flag: bool,
#[darling(default)]
Expand Down Expand Up @@ -192,16 +199,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
_ => panic!("the result and option attributes are mutually exclusive"),
};

if args.result_fallback && args.sync_writes {
panic!("the result_fallback and sync_writes attributes are mutually exclusive");
}

if args.result_fallback && args.sync_writes_by_key {
panic!("the result_fallback and sync_writes_by_key attributes are mutually exclusive");
}

if args.sync_writes && args.sync_writes_by_key {
panic!("the sync_writes and sync_writes_by_key attributes are mutually exclusive");
if args.result_fallback && args.sync_writes != SyncWriteMode::Disabled {
panic!("result_fallback and sync_writes are mutually exclusive");
}

let set_cache_and_return = quote! {
Expand All @@ -216,20 +215,19 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
let function_call;
let ty;
if asyncness.is_some() {
lock = if args.sync_writes_by_key {
quote! {
lock = match args.sync_writes {
SyncWriteMode::ByKey => quote! {
let mut locks = #cache_ident.lock().await;
let lock = locks
.entry(key.clone())
.or_insert_with(|| std::sync::Arc::new(::cached::async_sync::Mutex::new(#cache_create)))
.clone();
drop(locks);
let mut cache = lock.lock().await;
}
} else {
quote! {
},
_ => quote! {
let mut cache = #cache_ident.lock().await;
}
},
};

function_no_cache = quote! {
Expand All @@ -240,27 +238,25 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
let result = #no_cache_fn_ident(#(#input_names),*).await;
};

ty = if args.sync_writes_by_key {
quote! {
ty = match args.sync_writes {
SyncWriteMode::ByKey => quote! {
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<::cached::async_sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new()));
}
} else {
quote! {
},
_ => quote! {
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create));
}
},
};
} else {
lock = if args.sync_writes_by_key {
quote! {
lock = match args.sync_writes {
SyncWriteMode::ByKey => quote! {
let mut locks = #cache_ident.lock().unwrap();
let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone();
drop(locks);
let mut cache = lock.lock().unwrap();
}
} else {
quote! {
},
_ => quote! {
let mut cache = #cache_ident.lock().unwrap();
}
},
};

function_no_cache = quote! {
Expand All @@ -271,14 +267,13 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
let result = #no_cache_fn_ident(#(#input_names),*);
};

ty = if args.sync_writes_by_key {
quote! {
ty = match args.sync_writes {
SyncWriteMode::ByKey => quote! {
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<std::sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new()));
}
} else {
quote! {
},
_ => quote! {
#visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create));
}
},
}
}

Expand All @@ -290,7 +285,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
#set_cache_and_return
};

let do_set_return_block = if args.sync_writes_by_key || args.sync_writes {
let do_set_return_block = if args.sync_writes != SyncWriteMode::Disabled {
quote! {
#lock
if let Some(result) = cache.cache_get(&key) {
Expand Down
20 changes: 13 additions & 7 deletions cached_proc_macro/src/once.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@ use quote::quote;
use syn::spanned::Spanned;
use syn::{parse_macro_input, Ident, ItemFn, ReturnType};

#[derive(Debug, Default, FromMeta)]
enum SyncWriteMode {
#[default]
Disabled,
Default,
}

#[derive(FromMeta)]
struct OnceMacroArgs {
#[darling(default)]
name: Option<String>,
#[darling(default)]
time: Option<u64>,
#[darling(default)]
sync_writes: bool,
sync_writes: SyncWriteMode,
#[darling(default)]
result: bool,
#[darling(default)]
Expand Down Expand Up @@ -220,23 +227,22 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream {
}
};

let do_set_return_block = if args.sync_writes {
quote! {
let do_set_return_block = match args.sync_writes {
SyncWriteMode::Default => quote! {
#r_lock_return_cache_block
#w_lock
if let Some(result) = &*cached {
#return_cache_block
}
#function_call
#set_cache_and_return
}
} else {
quote! {
},
SyncWriteMode::Disabled => quote! {
#r_lock_return_cache_block
#function_call
#w_lock
#set_cache_and_return
}
},
};

let signature_no_muts = get_mut_signature(signature);
Expand Down
12 changes: 6 additions & 6 deletions tests/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ async fn test_only_cached_option_once_per_second_a() {
/// to return the cached result of the one call instead of all
/// concurrently un-cached tasks executing and writing concurrently.
#[cfg(feature = "async")]
#[once(time = 2, sync_writes = true)]
#[once(time = 2, sync_writes = "default")]
async fn only_cached_once_per_second_sync_writes(s: String) -> Vec<String> {
vec![s]
}
Expand All @@ -862,7 +862,7 @@ async fn test_only_cached_once_per_second_sync_writes() {
assert_eq!(a.await.unwrap(), b.await.unwrap());
}

#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")]
#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")]
fn cached_sync_writes(s: String) -> Vec<String> {
vec![s]
}
Expand All @@ -881,7 +881,7 @@ fn test_cached_sync_writes() {
}

#[cfg(feature = "async")]
#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")]
#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")]
async fn cached_sync_writes_a(s: String) -> Vec<String> {
vec![s]
}
Expand All @@ -898,7 +898,7 @@ async fn test_cached_sync_writes_a() {
assert_eq!(a, c.await.unwrap());
}

#[cached(time = 2, sync_writes_by_key = true, key = "u32", convert = "{ 1 }")]
#[cached(time = 2, sync_writes = "by_key", key = "u32", convert = "{ 1 }")]
fn cached_sync_writes_by_key(s: String) -> Vec<String> {
sleep(Duration::new(1, 0));
vec![s]
Expand All @@ -919,7 +919,7 @@ fn test_cached_sync_writes_by_key() {
#[cfg(feature = "async")]
#[cached(
time = 5,
sync_writes_by_key = true,
sync_writes = "by_key",
key = "String",
convert = r#"{ format!("{}", s) }"#
)]
Expand All @@ -942,7 +942,7 @@ async fn test_cached_sync_writes_by_key_a() {
}

#[cfg(feature = "async")]
#[once(sync_writes = true)]
#[once(sync_writes = "default")]
async fn once_sync_writes_a(s: &tokio::sync::Mutex<String>) -> String {
let mut guard = s.lock().await;
let results: String = (*guard).clone().to_string();
Expand Down

0 comments on commit f70e206

Please sign in to comment.