diff --git a/Cargo.lock b/Cargo.lock index a8d5b032..767076f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "10f00e1f6e58a40e807377c75c6a7f97bf9044fab57816f2414e6f5f4499d7b8" [[package]] name = "autocfg" @@ -631,6 +631,7 @@ dependencies = [ "log", "p3-field", "pilout", + "proofman-macros", "proofman-starks-lib-c", "proofman-util", "serde", @@ -648,6 +649,15 @@ dependencies = [ "proofman-starks-lib-c", ] +[[package]] +name = "proofman-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proofman-starks-lib-c" version = "0.1.0" @@ -825,18 +835,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 17a9ff62..1fe8a80a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "hints", "pilout", "proofman", + "macros", "provers/stark", "provers/starks-lib-c", "transcript", @@ -23,3 +24,4 @@ log = { version = "0.4", default-features = false } env_logger = "0.11" p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } +proofman-macros = { path = "macros", version = "0.1.0" } diff --git a/cli/assets/templates/pil_helpers_trace.rs.tt b/cli/assets/templates/pil_helpers_trace.rs.tt index f06e5a4e..b0df01e2 100644 --- a/cli/assets/templates/pil_helpers_trace.rs.tt +++ b/cli/assets/templates/pil_helpers_trace.rs.tt @@ -1,6 +1,7 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. -use proofman_common::trace; +pub use proofman_macros::trace; +use proofman_common as common; {{ for air_group in air_groups }} {{ for air in air_group.airs }} diff --git a/common/Cargo.toml b/common/Cargo.toml index ee13a6c7..66e77389 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -8,8 +8,9 @@ serde = { version = "1.0.130", features = ["derive"] } serde_json = "1.0.68" serde_derive = "1.0.196" pilout = { path = "../pilout" } -log = { workspace = true } +log.workspace = true transcript = { path = "../transcript" } -p3-field = { workspace = true } +p3-field.workspace = true +proofman-macros.workspace = true proofman-util = { path = "../util" } proofman-starks-lib-c = { path = "../provers/starks-lib-c" } diff --git a/common/src/trace.rs b/common/src/trace.rs index b5db8839..1ca6e89a 100644 --- a/common/src/trace.rs +++ b/common/src/trace.rs @@ -3,236 +3,165 @@ pub trait Trace: Send { fn get_buffer_ptr(&mut self) -> *mut u8; } -#[macro_export] -macro_rules! trace { - ($row_struct_name:ident, $trace_struct_name:ident<$generic:ident> { - $( $field_name:ident : $field_type:ty ),* $(,)? - }) => { - // Define the row structure (Main0RowTrace) - #[allow(dead_code)] - #[derive(Debug, Clone, Copy, Default)] - pub struct $row_struct_name<$generic> { - $( pub $field_name: $field_type ),* - } - - - impl<$generic: Copy> $row_struct_name<$generic> { - // The size of each row in terms of the number of fields - pub const ROW_SIZE: usize = 0 $(+ trace!(@count_elements $field_type))*; - } - - // Define the trace structure (Main0Trace) that manages the row structure - pub struct $trace_struct_name<'a, $generic> { - pub buffer: Option>, - pub slice_trace: &'a mut [$row_struct_name<$generic>], - num_rows: usize, - } - - impl<'a, $generic: Default + Clone + Copy> $trace_struct_name<'a, $generic> { - // Constructor for creating a new buffer - pub fn new(num_rows: usize) -> Self { - // PRECONDITIONS - // num_rows must be greater than or equal to 2 - assert!(num_rows >= 2); - // num_rows must be a power of 2 - assert!(num_rows & (num_rows - 1) == 0); - - let buffer = vec![$generic::default(); num_rows * $row_struct_name::<$generic>::ROW_SIZE]; - - let slice_trace = unsafe { - std::slice::from_raw_parts_mut(buffer.as_ptr() as *mut $row_struct_name<$generic>, num_rows) - }; - - $trace_struct_name { buffer: Some(buffer), slice_trace, num_rows } - } - - // Constructor to map over an external buffer - pub fn map_buffer(external_buffer: &'a mut [$generic], num_rows: usize, offset: usize) -> Result> { - // PRECONDITIONS - // num_rows must be greater than or equal to 2 - assert!(num_rows >= 2); - // num_rows must be a power of 2 - assert!(num_rows & (num_rows - 1) == 0); - - let start = offset; - let end = start + num_rows * $row_struct_name::<$generic>::ROW_SIZE; - - if end > external_buffer.len() { - return Err("Buffer is too small to fit the trace".into()); - } - - let slice_trace = unsafe { - std::slice::from_raw_parts_mut( - external_buffer[start..end].as_ptr() as *mut $row_struct_name<$generic>, - num_rows, - ) - }; - - Ok($trace_struct_name { - buffer: None, - slice_trace, - num_rows, - }) - } - - // Constructor to map over an external buffer - pub fn map_row_vec(external_buffer: Vec<$row_struct_name<$generic>>) -> Result> { - let num_rows = external_buffer.len().next_power_of_two(); - - // PRECONDITIONS - // num_rows must be greater than or equal to 2 - assert!(num_rows >= 2); - // num_rows must be a power of 2 - assert!(num_rows & (num_rows - 1) == 0); - - let slice_trace = unsafe { - let ptr = external_buffer.as_ptr() as *mut $row_struct_name<$generic>; - std::slice::from_raw_parts_mut(ptr, - num_rows, - ) - }; - - let buffer_f = unsafe { - Vec::from_raw_parts(external_buffer.as_ptr() as *mut $generic, num_rows * $row_struct_name::<$generic>::ROW_SIZE, num_rows * $row_struct_name::<$generic>::ROW_SIZE) - }; - - std::mem::forget(external_buffer); - - Ok($trace_struct_name { - buffer: Some(buffer_f), - slice_trace, - num_rows, - }) - } - - pub fn num_rows(&self) -> usize { - self.num_rows - } - } - - // Implement Index trait for immutable access - impl<'a, $generic> std::ops::Index for $trace_struct_name<'a, $generic> { - type Output = $row_struct_name<$generic>; - - fn index(&self, index: usize) -> &Self::Output { - &self.slice_trace[index] - } - } - - // Implement IndexMut trait for mutable access - impl<'a, $generic> std::ops::IndexMut for $trace_struct_name<'a, $generic> { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.slice_trace[index] - } - } - - // Implement the Trace trait - impl<'a, $generic: Send > $crate::trace::Trace for $trace_struct_name<'a, $generic> { - fn num_rows(&self) -> usize { - self.num_rows - } - - fn get_buffer_ptr(&mut self) -> *mut u8 { - let buffer = self.buffer.as_mut().expect("Buffer is not available"); - buffer.as_mut_ptr() as *mut u8 - } - } - }; - - (@count_elements $elem_type:ty) => { - if std::mem::size_of::<$elem_type>() == 8 { - 1 - } else if std::mem::size_of::<$elem_type>() == std::mem::size_of::<&str>() { - trace!(@parse_string $elem_type) - } else { - 0 - } - }; - - (@parse_string $elem_type:ty) => { - // Check if the element type is a string array - if true { - let field_str = stringify!($elem_type); - let a_bytes = "[F; 2]".as_bytes(); - let b_bytes = field_str.as_bytes(); - let mut result = 0; - let mut i = 0; - let mut c = 0; - if a_bytes.len() == b_bytes.len() { - while i < b_bytes.len() { - if a_bytes[i] != b_bytes[i] { - c+=1; - } - i+=1; - } - if c == 0 { - result = 2; - } - } - result - - } else { - 0 - } - }; -} +pub use proofman_macros::trace; + #[cfg(test)] -mod tests { - // use rand::Rng; +use crate as common; + +#[test] +#[should_panic] +fn test_errors_are_launched_when_num_rows_is_invalid_1() { + let mut buffer = vec![0u8; 3]; + trace!(SimpleRow, Simple { a: F }); + let _ = Simple::map_buffer(&mut buffer, 1, 0); +} + +#[test] +#[should_panic] +fn test_errors_are_launched_when_num_rows_is_invalid_2() { + let mut buffer = vec![0u8; 3]; + trace!(SimpleRow, Simple { a: F }); + let _ = Simple::map_buffer(&mut buffer, 3, 0); +} + +#[test] +#[should_panic] +fn test_errors_are_launched_when_num_rows_is_invalid_3() { + trace!(SimpleRow, Simple { a: F }); + let _ = Simple::::new(1); +} + +#[test] +#[should_panic] +fn test_errors_are_launched_when_num_rows_is_invalid_4() { + trace!(SimpleRow, Simple { a: F }); + let _ = Simple::::new(3); +} + +#[test] +fn check() { + const OFFSET: usize = 1; + let num_rows = 8; - #[test] - fn check() { - const OFFSET: usize = 1; - let num_rows = 8; + trace!(TraceRow, MyTrace { a: F, b:F}); - trace!(TraceRow, MyTrace { a: F, b:F}); + assert_eq!(TraceRow::::ROW_SIZE, 2); - assert_eq!(TraceRow::::ROW_SIZE, 2); + let mut buffer = vec![0usize; num_rows * TraceRow::::ROW_SIZE + OFFSET]; + let trace = MyTrace::map_buffer(&mut buffer, num_rows, OFFSET); + let mut trace = trace.unwrap(); - let mut buffer = vec![0usize; num_rows * TraceRow::::ROW_SIZE + OFFSET]; - let trace = MyTrace::map_buffer(&mut buffer, num_rows, OFFSET); - let mut trace = trace.unwrap(); + // Set values + for i in 0..num_rows { + trace[i].a = i; + trace[i].b = i * 10; + } + + // Check values + for i in 0..num_rows { + assert_eq!(trace[i].a, i); + assert_eq!(trace[i].b, i * 10); + } +} + +#[test] +fn check_array() { + let num_rows = 8; - // Set values - for i in 0..num_rows { - trace[i].a = i; - trace[i].b = i * 10; - } + trace!(TraceRow, MyTrace { a: F, b: [F; 3], c: F }); - // Check values - for i in 0..num_rows { - assert_eq!(trace[i].a, i); - assert_eq!(trace[i].b, i * 10); - } + assert_eq!(TraceRow::::ROW_SIZE, 5); + let mut buffer = vec![0usize; num_rows * TraceRow::::ROW_SIZE]; + let trace = MyTrace::map_buffer(&mut buffer, num_rows, 0); + let mut trace = trace.unwrap(); + + // Set values + for i in 0..num_rows { + trace[i].a = i; + trace[i].b[0] = i * 10; + trace[i].b[1] = i * 20; + trace[i].b[2] = i * 30; + trace[i].c = i * 40; } - #[test] - #[should_panic] - fn test_errors_are_launched_when_num_rows_is_invalid_1() { - let mut buffer = vec![0u8; 3]; - trace!(SimpleRow, Simple { a: F }); - let _ = Simple::map_buffer(&mut buffer, 1, 0); + // Check values + for i in 0..num_rows { + assert_eq!(buffer[i * TraceRow::::ROW_SIZE], i); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 1], i * 10); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 2], i * 20); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 3], i * 30); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 4], i * 40); } +} - #[test] - #[should_panic] - fn test_errors_are_launched_when_num_rows_is_invalid_2() { - let mut buffer = vec![0u8; 3]; - trace!(SimpleRow, Simple { a: F }); - let _ = Simple::map_buffer(&mut buffer, 3, 0); +#[test] +fn check_multi_array() { + let num_rows = 8; + + trace!(TraceRow, MyTrace { a: [[F;3]; 2], b: F }); + + assert_eq!(TraceRow::::ROW_SIZE, 7); + let mut buffer = vec![0usize; num_rows * TraceRow::::ROW_SIZE]; + let trace = MyTrace::map_buffer(&mut buffer, num_rows, 0); + let mut trace = trace.unwrap(); + + // Set values + for i in 0..num_rows { + trace[i].a[0][0] = i; + trace[i].a[0][1] = i * 10; + trace[i].a[0][2] = i * 20; + trace[i].a[1][0] = i * 30; + trace[i].a[1][1] = i * 40; + trace[i].a[1][2] = i * 50; + trace[i].b = i + 3; } - #[test] - #[should_panic] - fn test_errors_are_launched_when_num_rows_is_invalid_3() { - trace!(SimpleRow, Simple { a: F }); - let _ = Simple::::new(1); + // Check values + for i in 0..num_rows { + assert_eq!(buffer[i * TraceRow::::ROW_SIZE], i); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 1], i * 10); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 2], i * 20); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 3], i * 30); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 4], i * 40); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 5], i * 50); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 6], i + 3); + } +} + +#[test] +fn check_multi_array_2() { + let num_rows = 8; + + trace!(TraceRow, MyTrace { a: [[F;3]; 2], b: F, c: [F; 2] }); + + assert_eq!(TraceRow::::ROW_SIZE, 9); + let mut buffer = vec![0usize; num_rows * TraceRow::::ROW_SIZE]; + let trace = MyTrace::map_buffer(&mut buffer, num_rows, 0); + let mut trace = trace.unwrap(); + + // Set values + for i in 0..num_rows { + trace[i].a[0][0] = i; + trace[i].a[0][1] = i * 10; + trace[i].a[0][2] = i * 20; + trace[i].a[1][0] = i * 30; + trace[i].a[1][1] = i * 40; + trace[i].a[1][2] = i * 50; + trace[i].b = i + 3; + trace[i].c[0] = i + 9; + trace[i].c[1] = i + 2; } - #[test] - #[should_panic] - fn test_errors_are_launched_when_num_rows_is_invalid_4() { - trace!(SimpleRow, Simple { a: F }); - let _ = Simple::::new(3); + // Check values + for i in 0..num_rows { + assert_eq!(buffer[i * TraceRow::::ROW_SIZE], i); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 1], i * 10); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 2], i * 20); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 3], i * 30); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 4], i * 40); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 5], i * 50); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 6], i + 3); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 7], i + 9); + assert_eq!(buffer[i * TraceRow::::ROW_SIZE + 8], i + 2); } } diff --git a/examples/fibonacci-square/Cargo.toml b/examples/fibonacci-square/Cargo.toml index b4d39399..d30116fa 100644 --- a/examples/fibonacci-square/Cargo.toml +++ b/examples/fibonacci-square/Cargo.toml @@ -8,6 +8,7 @@ crate-type = ["dylib"] [dependencies] proofman-common = { path = "../../common" } +proofman-macros.workspace = true proofman = { path = "../../proofman" } # pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-components.git", branch ="std_rust" } diff --git a/macros/Cargo.toml b/macros/Cargo.toml new file mode 100644 index 00000000..dbc57525 --- /dev/null +++ b/macros/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "proofman-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2", features = ["full"] } +quote = "1" +proc-macro2 = "1" diff --git a/macros/src/lib.rs b/macros/src/lib.rs new file mode 100644 index 00000000..da12b836 --- /dev/null +++ b/macros/src/lib.rs @@ -0,0 +1,542 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, format_ident, ToTokens}; +use syn::{ + parse2, + parse::{Parse, ParseStream}, + Ident, Generics, FieldsNamed, Result, Field, Token, +}; + +#[proc_macro] +pub fn trace(input: TokenStream) -> TokenStream { + match trace_impl(input.into()) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +fn trace_impl(input: TokenStream2) -> Result { + let parsed_input: ParsedTraceInput = parse2(input)?; + + let row_struct_name = parsed_input.row_struct_name; + let trace_struct_name = parsed_input.struct_name; + let generics = parsed_input.generics.params; + let fields = parsed_input.fields; + + // Calculate ROW_SIZE based on the field types + let row_size = fields + .named + .iter() + .map(|field| calculate_field_size_literal(&field.ty)) + .collect::>>()? + .into_iter() + .sum::(); + + // Generate row struct + let field_definitions = fields.named.iter().map(|field| { + let Field { ident, ty, .. } = field; + quote! { pub #ident: #ty, } + }); + + let row_struct = quote! { + #[repr(C)] + #[derive(Debug, Clone, Copy, Default)] + pub struct #row_struct_name<#generics> { + #(#field_definitions)* + } + + impl<#generics: Copy> #row_struct_name<#generics> { + pub const ROW_SIZE: usize = #row_size; + } + }; + + // Generate trace struct + let trace_struct = quote! { + pub struct #trace_struct_name<'a, #generics> { + pub buffer: Option>, + pub slice_trace: &'a mut [#row_struct_name<#generics>], + num_rows: usize, + } + + impl<'a, #generics: Default + Clone + Copy> #trace_struct_name<'a, #generics> { + pub fn new(num_rows: usize) -> Self { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + + let buffer = vec![#generics::default(); num_rows * #row_struct_name::<#generics>::ROW_SIZE]; + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + buffer.as_ptr() as *mut #row_struct_name<#generics>, + num_rows, + ) + }; + + #trace_struct_name { + buffer: Some(buffer), + slice_trace, + num_rows, + } + } + + pub fn map_buffer( + external_buffer: &'a mut [#generics], + num_rows: usize, + offset: usize, + ) -> Result> { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + + let start = offset; + let end = start + num_rows * #row_struct_name::<#generics>::ROW_SIZE; + + if end > external_buffer.len() { + return Err("Buffer is too small to fit the trace".into()); + } + + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + external_buffer[start..end].as_ptr() as *mut #row_struct_name<#generics>, + num_rows, + ) + }; + + Ok(#trace_struct_name { + buffer: None, + slice_trace, + num_rows, + }) + } + + pub fn map_row_vec( + external_buffer: Vec<#row_struct_name<#generics>>, + ) -> Result> { + let num_rows = external_buffer.len().next_power_of_two(); + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + + let slice_trace = unsafe { + let ptr = external_buffer.as_ptr() as *mut #row_struct_name<#generics>; + std::slice::from_raw_parts_mut( + ptr, + num_rows, + ) + }; + + let buffer_f = unsafe { + Vec::from_raw_parts( + external_buffer.as_ptr() as *mut #generics, + num_rows * #row_struct_name::<#generics>::ROW_SIZE, + num_rows * #row_struct_name::<#generics>::ROW_SIZE, + ) + }; + + std::mem::forget(external_buffer); + + Ok(#trace_struct_name { + buffer: Some(buffer_f), + slice_trace, + num_rows, + }) + } + + pub fn num_rows(&self) -> usize { + self.num_rows + } + } + + impl<'a, #generics> std::ops::Index for #trace_struct_name<'a, #generics> { + type Output = #row_struct_name<#generics>; + + fn index(&self, index: usize) -> &Self::Output { + &self.slice_trace[index] + } + } + + impl<'a, #generics> std::ops::IndexMut for #trace_struct_name<'a, #generics> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.slice_trace[index] + } + } + + impl<'a, #generics: Send> common::trace::Trace for #trace_struct_name<'a, #generics> { + fn num_rows(&self) -> usize { + self.num_rows + } + + fn get_buffer_ptr(&mut self) -> *mut u8 { + let buffer = self.buffer.as_mut().expect("Buffer is not available"); + buffer.as_mut_ptr() as *mut u8 + } + } + }; + + Ok(quote! { + #row_struct + #trace_struct + }) +} + +// A struct to handle parsing the input and all the syntactic variations +struct ParsedTraceInput { + row_struct_name: Ident, + struct_name: Ident, + generics: Generics, + fields: FieldsNamed, +} + +impl Parse for ParsedTraceInput { + fn parse(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + let row_struct_name; + + // Handle explicit or implicit row struct names + if lookahead.peek(Ident) && input.peek2(Token![,]) { + row_struct_name = Some(input.parse::()?); + input.parse::()?; // Skip comma after explicit row name + } else { + row_struct_name = None; + } + + let struct_name = input.parse::()?; + let row_struct_name = row_struct_name.unwrap_or_else(|| format_ident!("{}Row", struct_name)); + + let generics: Generics = input.parse()?; + let fields: FieldsNamed = input.parse()?; + + Ok(ParsedTraceInput { row_struct_name, struct_name, generics, fields }) + } +} + +// Calculate the size of a field based on its type and return it as a Result +fn calculate_field_size_literal(field_type: &syn::Type) -> Result { + match field_type { + // Handle arrays with multiple dimensions + syn::Type::Array(type_array) => { + let len = type_array.len.to_token_stream().to_string().parse::().map_err(|e| { + syn::Error::new_spanned(&type_array.len, format!("Failed to parse array length: {}", e)) + })?; + let elem_size = calculate_field_size_literal(&type_array.elem)?; + Ok(len * elem_size) + } + // For simple types, the size is 1 + _ => Ok(1), + } +} + +#[test] +fn test_trace_macro_generates_default_row_struct() { + let input = quote! { + Simple { a: F, b: F } + }; + + let expected = quote! { + #[repr(C)] + #[derive(Debug, Clone, Copy, Default)] + pub struct SimpleRow { + pub a: F, + pub b: F, + } + impl SimpleRow { + pub const ROW_SIZE: usize = 2usize; + } + pub struct Simple<'a, F> { + pub buffer: Option>, + pub slice_trace: &'a mut [SimpleRow], + num_rows: usize, + } + impl<'a, F: Default + Clone + Copy> Simple<'a, F> { + pub fn new(num_rows: usize) -> Self { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let buffer = vec![F::default(); num_rows * SimpleRow::::ROW_SIZE]; + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + buffer.as_ptr() as *mut SimpleRow, + num_rows, + ) + }; + Simple { + buffer: Some(buffer), + slice_trace, + num_rows, + } + } + pub fn map_buffer( + external_buffer: &'a mut [F], + num_rows: usize, + offset: usize, + ) -> Result> { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let start = offset; + let end = start + num_rows * SimpleRow::::ROW_SIZE; + if end > external_buffer.len() { + return Err("Buffer is too small to fit the trace".into()); + } + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + external_buffer[start..end].as_ptr() as *mut SimpleRow, + num_rows, + ) + }; + Ok(Simple { + buffer: None, + slice_trace, + num_rows, + }) + } + pub fn map_row_vec( + external_buffer: Vec>, + ) -> Result> { + let num_rows = external_buffer.len().next_power_of_two(); + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let slice_trace = unsafe { + let ptr = external_buffer.as_ptr() as *mut SimpleRow; + std::slice::from_raw_parts_mut( + ptr, + num_rows, + ) + }; + let buffer_f = unsafe { + Vec::from_raw_parts( + external_buffer.as_ptr() as *mut F, + num_rows * SimpleRow::::ROW_SIZE, + num_rows * SimpleRow::::ROW_SIZE, + ) + }; + std::mem::forget(external_buffer); + Ok(Simple { + buffer: Some(buffer_f), + slice_trace, + num_rows, + }) + } + pub fn num_rows(&self) -> usize { + self.num_rows + } + } + impl<'a, F> std::ops::Index for Simple<'a, F> { + type Output = SimpleRow; + fn index(&self, index: usize) -> &Self::Output { + &self.slice_trace[index] + } + } + impl<'a, F> std::ops::IndexMut for Simple<'a, F> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.slice_trace[index] + } + } + impl<'a, F: Send> common::trace::Trace for Simple<'a, F> { + fn num_rows(&self) -> usize { + self.num_rows + } + fn get_buffer_ptr(&mut self) -> *mut u8 { + let buffer = self.buffer.as_mut().expect("Buffer is not available"); + buffer.as_mut_ptr() as *mut u8 + } + } + + + }; + let generated = trace_impl(input.into()).unwrap(); + assert_eq!(generated.to_string(), expected.into_token_stream().to_string()); +} + +#[test] +fn test_trace_macro_with_explicit_row_struct_name() { + let input = quote! { + SimpleRow, Simple { a: F, b: F } + }; + + let expected = quote! { + #[repr(C)] + #[derive(Debug, Clone, Copy, Default)] + pub struct SimpleRow { + pub a: F, + pub b: F, + } + + impl SimpleRow { + pub const ROW_SIZE: usize = 2usize; + } + + pub struct Simple<'a, F> { + pub buffer: Option>, + pub slice_trace: &'a mut [SimpleRow], + num_rows: usize, + } + + impl<'a, F: Default + Clone + Copy> Simple<'a, F> { + pub fn new(num_rows: usize) -> Self { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let buffer = vec![F::default(); num_rows * SimpleRow::::ROW_SIZE]; + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + buffer.as_ptr() as *mut SimpleRow, + num_rows, + ) + }; + Simple { + buffer: Some(buffer), + slice_trace, + num_rows, + } + } + + pub fn map_buffer( + external_buffer: &'a mut [F], + num_rows: usize, + offset: usize, + ) -> Result> { + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let start = offset; + let end = start + num_rows * SimpleRow::::ROW_SIZE; + if end > external_buffer.len() { + return Err("Buffer is too small to fit the trace".into()); + } + let slice_trace = unsafe { + std::slice::from_raw_parts_mut( + external_buffer[start..end].as_ptr() as *mut SimpleRow, + num_rows, + ) + }; + Ok(Simple { + buffer: None, + slice_trace, + num_rows, + }) + } + + pub fn map_row_vec( + external_buffer: Vec>, + ) -> Result> { + let num_rows = external_buffer.len().next_power_of_two(); + assert!(num_rows >= 2); + assert!(num_rows & (num_rows - 1) == 0); + let slice_trace = unsafe { + let ptr = external_buffer.as_ptr() as *mut SimpleRow; + std::slice::from_raw_parts_mut( + ptr, + num_rows, + ) + }; + let buffer_f = unsafe { + Vec::from_raw_parts( + external_buffer.as_ptr() as *mut F, + num_rows * SimpleRow::::ROW_SIZE, num_rows * SimpleRow::::ROW_SIZE, + ) + }; + std::mem::forget(external_buffer); + Ok(Simple { + buffer: Some(buffer_f), + slice_trace, num_rows, + }) + } + + pub fn num_rows(&self) -> usize { + self.num_rows + } + } + + impl<'a, F> std::ops::Index for Simple<'a, F> { + type Output = SimpleRow; + + fn index(&self, index: usize) -> &Self::Output { + &self.slice_trace[index] + } + } + + impl<'a, F> std::ops::IndexMut for Simple<'a, F> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.slice_trace[index] + } + } + + impl<'a, F: Send> common::trace::Trace for Simple<'a, F> { + fn num_rows(&self) -> usize { + self.num_rows + } + + fn get_buffer_ptr(&mut self) -> *mut u8 { + let buffer = self.buffer.as_mut().expect("Buffer is not available"); + buffer.as_mut_ptr() as *mut u8 + } + } + }; + + let generated = trace_impl(input.into()).unwrap(); + assert_eq!(generated.to_string(), expected.into_token_stream().to_string()); +} + +#[test] +fn test_parsing_01() { + let input = quote! { + TraceRow, MyTrace { a: F, b: F } + }; + let parsed: ParsedTraceInput = parse2(input).unwrap(); + assert_eq!(parsed.row_struct_name, "TraceRow"); + assert_eq!(parsed.struct_name, "MyTrace"); +} + +#[test] +fn test_parsing_02() { + let input = quote! { + SimpleRow, Simple { a: F } + }; + let parsed: ParsedTraceInput = parse2(input).unwrap(); + assert_eq!(parsed.row_struct_name, "SimpleRow"); + assert_eq!(parsed.struct_name, "Simple"); +} + +#[test] +fn test_parsing_03() { + let input = quote! { + Simple { a: F } + }; + let parsed: ParsedTraceInput = parse2(input).unwrap(); + assert_eq!(parsed.row_struct_name, "SimpleRow"); + assert_eq!(parsed.struct_name, "Simple"); +} + +#[test] +fn test_simple_type_size() { + // A simple type like `F` should return size 1 + let ty: syn::Type = syn::parse_quote! { F }; + let size = calculate_field_size_literal(&ty).unwrap(); + assert_eq!(size, 1); +} + +#[test] +fn test_array_type_size_single_dimension() { + // An array like `[F; 3]` should return size 3 + let ty: syn::Type = syn::parse_quote! { [F; 3] }; + let size = calculate_field_size_literal(&ty).unwrap(); + assert_eq!(size, 3); +} + +#[test] +fn test_array_type_size_multi_dimension() { + // A multi-dimensional array like `[[F; 3]; 2]` should return size 6 (2 * 3) + let ty: syn::Type = syn::parse_quote! { [[F; 3]; 2] }; + let size = calculate_field_size_literal(&ty).unwrap(); + assert_eq!(size, 6); +} + +#[test] +fn test_nested_array_type_size() { + // A more deeply nested array like `[[[F; 2]; 3]; 4]` should return size 24 (4 * 3 * 2) + let ty: syn::Type = syn::parse_quote! { [[[F; 2]; 3]; 4] }; + let size = calculate_field_size_literal(&ty).unwrap(); + assert_eq!(size, 24); +} + +#[test] +fn test_empty_array() { + // An empty array should return size 0 + let ty: syn::Type = syn::parse_quote! { [F; 0] }; + let size = calculate_field_size_literal(&ty).unwrap(); + assert_eq!(size, 0); +}