From 19c8b55f90ee55a9514edb2e22e7994fdf30b631 Mon Sep 17 00:00:00 2001 From: Ethan Cemer Date: Fri, 8 Sep 2023 19:12:54 -0500 Subject: [PATCH 1/4] *added wasm binding + test --- src/wasm.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++ tests/wasm.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/wasm.rs b/src/wasm.rs index 96c58332f..3c4cbf1ee 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -70,6 +70,43 @@ pub fn floatToVecU64(input: f64, scale: u32) -> wasm_bindgen::Clamped> { wasm_bindgen::Clamped(serde_json::to_vec(&vec).unwrap()) } +/// Converts a buffer to 4 u64s representing a fixed point field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn bufferToVecOfVecU64(buffer: wasm_bindgen::Clamped>) -> wasm_bindgen::Clamped> { + // Convert the buffer to a slice + let buffer: &[u8] = &buffer; + + // Divide the buffer into chunks of 64 bytes + let chunks = buffer.chunks_exact(16); + + // Get the remainder + let remainder = chunks.remainder(); + + // Add 0s to the remainder to make it 64 bytes + let mut remainder = remainder.to_vec(); + remainder.resize(16, 0); + + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder.try_into().expect("Slice must be of length 16"); + + // Collect chunks into a Vec<[u8; 16]>. + let mut chunks: Vec<[u8; 16]> = chunks.map(|slice| { + let array: [u8; 16] = slice.try_into().expect("Slice must be of length 16"); + array + }).collect(); + + // append the remainder to the chunks + chunks.push(remainder_array); + + // Convert each chunk to a field element + let field_elements: Vec = chunks.iter().map( + |x| PrimeField::from_u128(u8_array_to_u128_le(*x)) + ).collect(); + + wasm_bindgen::Clamped(serde_json::to_vec(&field_elements).unwrap()) +} + /// Generate a poseidon hash in browser. Input message #[wasm_bindgen] #[allow(non_snake_case)] @@ -263,3 +300,13 @@ where let pk = keygen_pk(params, vk, &empty_circuit)?; Ok(pk) } + +/// +pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { + let mut n: u128 = 0; + for &b in arr.iter().rev() { + n <<= 8; + n |= b as u128; + } + n +} diff --git a/tests/wasm.rs b/tests/wasm.rs index d56d3fe1d..9ce0e13bd 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -11,13 +11,14 @@ mod wasm32 { use ezkl::pfsys::Snark; use ezkl::wasm::{ elgamalDecrypt, elgamalEncrypt, elgamalGenRandom, poseidonHash, prove, vecU64ToFelt, - vecU64ToFloat, vecU64ToInt, verify, genWitness + vecU64ToFloat, vecU64ToInt, verify, genWitness, bufferToVecOfVecU64, u8_array_to_u128_le }; use halo2curves::bn256::{Fr, G1Affine}; use rand::rngs::StdRng; use rand::SeedableRng; #[cfg(feature = "web")] pub use wasm_bindgen_rayon::init_thread_pool; + use snark_verifier::util::arithmetic::PrimeField; use wasm_bindgen_test::*; wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); @@ -49,6 +50,54 @@ mod wasm32 { assert_eq!(hex_string, returned_string); } } + + #[wasm_bindgen_test] + async fn verify_buffer_to_field_elements() { + + let string_high = String::from("high"); + let mut buffer = string_high.clone().into_bytes(); + let clamped = wasm_bindgen::Clamped(buffer.clone()); + + let field_elements_ser = bufferToVecOfVecU64(clamped); + + let field_elements: Vec = serde_json::from_slice(&field_elements_ser[..]).unwrap(); + + buffer.resize(16, 0); + + let reference_int = u8_array_to_u128_le(buffer.try_into().unwrap()); + + let reference_field_element_high = PrimeField::from_u128(reference_int); + + assert_eq!(field_elements[0], reference_field_element_high); + + // length 16 string (divisible by 16 so doesn't need padding) + let string_sample = String::from("a sample string!"); + let buffer = string_sample.clone().into_bytes(); + let clamped = wasm_bindgen::Clamped(buffer.clone()); + + let field_elements_ser = bufferToVecOfVecU64(clamped); + + let field_elements: Vec = serde_json::from_slice(&field_elements_ser[..]).unwrap(); + + let reference_int = u8_array_to_u128_le(buffer.try_into().unwrap()); + + let reference_field_element_sample = PrimeField::from_u128(reference_int); + + assert_eq!(field_elements[0], reference_field_element_sample); + + let string_concat = string_sample + &string_high; + + let buffer = string_concat.into_bytes(); + let clamped = wasm_bindgen::Clamped(buffer.clone()); + + let field_elements_ser = bufferToVecOfVecU64(clamped); + + let field_elements: Vec = serde_json::from_slice(&field_elements_ser[..]).unwrap(); + + assert_eq!(field_elements[0], reference_field_element_sample); + assert_eq!(field_elements[1], reference_field_element_high); + + } #[wasm_bindgen_test] async fn verify_elgamal_gen_random_wasm() { From 32ab7de80707177d837edb7a686bc57e022494d4 Mon Sep 17 00:00:00 2001 From: Ethan Cemer Date: Fri, 8 Sep 2023 20:46:50 -0500 Subject: [PATCH 2/4] *finished wasm + python implementation and tests --- src/python.rs | 58 +++++++++++++++++++++++++++++++++-- src/wasm.rs | 17 +++++----- tests/python/binding_tests.py | 13 ++++++++ 3 files changed, 77 insertions(+), 11 deletions(-) diff --git a/src/python.rs b/src/python.rs index d369333be..420bccea7 100644 --- a/src/python.rs +++ b/src/python.rs @@ -19,6 +19,7 @@ use pyo3_log; use std::str::FromStr; use std::{fs::File, path::PathBuf}; use tokio::runtime::Runtime; +use snark_verifier::util::arithmetic::PrimeField; /// pyclass containing the struct used for run_args #[pyclass] @@ -130,9 +131,59 @@ fn float_to_vecu64(input: f64, scale: u32) -> PyResult<[u64; 4]> { let int_rep = quantize_float(&input, 0.0, scale) .map_err(|_| PyIOError::new_err("Failed to quantize input"))?; let felt = i128_to_felt(int_rep); - Ok(crate::pfsys::field_to_vecu64_montgomery::< - halo2curves::bn256::Fr, - >(&felt)) + Ok(crate::pfsys::field_to_vecu64_montgomery::(&felt)) +} + +/// Converts a buffer to vector of 4 u64s representing a fixed point field element +#[pyfunction(signature = ( + buffer + ))] +fn buffer_to_vec_of_vecu64(buffer: Vec) -> PyResult> { + + fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { + let mut n: u128 = 0; + for &b in arr.iter().rev() { + n <<= 8; + n |= b as u128; + } + n + } + + let buffer = &buffer[..]; + + // Divide the buffer into chunks of 64 bytes + let chunks = buffer.chunks_exact(16); + + // Get the remainder + let remainder = chunks.remainder(); + + // Add 0s to the remainder to make it 64 bytes + let mut remainder = remainder.to_vec(); + + // Collect chunks into a Vec<[u8; 16]>. + let mut chunks: Vec<[u8; 16]> = chunks.map(|slice| { + let array: [u8; 16] = slice.try_into().unwrap(); + array + }).collect(); + + if remainder.len() != 0 { + remainder.resize(16, 0); + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder.try_into().unwrap(); + // append the remainder to the chunks + chunks.push(remainder_array); + } + + // Convert each chunk to a field element + let field_elements: Vec = chunks.iter().map( + |x| PrimeField::from_u128(u8_array_to_u128_le(*x)) + ).collect(); + + let field_elements: Vec = field_elements.iter().map( + |x| format!("{:?}", x) + ).collect(); + + Ok(field_elements) } /// Generates a vk from a pk for a model circuit and saves it to a file @@ -736,6 +787,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?; m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?; m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?; + m.add_function(wrap_pyfunction!(buffer_to_vec_of_vecu64, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?; m.add_function(wrap_pyfunction!(table, m)?)?; diff --git a/src/wasm.rs b/src/wasm.rs index 3c4cbf1ee..a895338c5 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -70,7 +70,7 @@ pub fn floatToVecU64(input: f64, scale: u32) -> wasm_bindgen::Clamped> { wasm_bindgen::Clamped(serde_json::to_vec(&vec).unwrap()) } -/// Converts a buffer to 4 u64s representing a fixed point field element +/// Converts a buffer to vector of 4 u64s representing a fixed point field element #[wasm_bindgen] #[allow(non_snake_case)] pub fn bufferToVecOfVecU64(buffer: wasm_bindgen::Clamped>) -> wasm_bindgen::Clamped> { @@ -85,19 +85,20 @@ pub fn bufferToVecOfVecU64(buffer: wasm_bindgen::Clamped>) -> wasm_bindg // Add 0s to the remainder to make it 64 bytes let mut remainder = remainder.to_vec(); - remainder.resize(16, 0); - - // Convert the Vec to [u8; 16] - let remainder_array: [u8; 16] = remainder.try_into().expect("Slice must be of length 16"); // Collect chunks into a Vec<[u8; 16]>. let mut chunks: Vec<[u8; 16]> = chunks.map(|slice| { - let array: [u8; 16] = slice.try_into().expect("Slice must be of length 16"); + let array: [u8; 16] = slice.try_into().unwrap(); array }).collect(); - // append the remainder to the chunks - chunks.push(remainder_array); + if remainder.len() != 0 { + remainder.resize(16, 0); + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder.try_into().unwrap(); + // append the remainder to the chunks + chunks.push(remainder_array); + } // Convert each chunk to a field element let field_elements: Vec = chunks.iter().map( diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index 4c7e5ac7e..a7b0be3b8 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -69,6 +69,19 @@ def test_field_serialization(): roundtrip_input = ezkl.vecu64_to_float(felt, scale) assert input == roundtrip_input +def test_buffer_to_vec_of_vecu64(): + """ + Test buffer_to_vec_of_vecu64 + """ + buffer = bytearray("a sample string!", 'utf-8') + felts = ezkl.buffer_to_vec_of_vecu64(buffer) + ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061" + assert felts == [ref_felt_1] + + buffer = bytearray("a sample string!"+"high", 'utf-8') + felts = ezkl.buffer_to_vec_of_vecu64(buffer) + ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968" + assert felts == [ref_felt_1,ref_felt_2] def test_table_1l_average(): """ From 3bf75afec016b3a2d7a3f97d84f46614bb2e2491 Mon Sep 17 00:00:00 2001 From: Ethan Cemer Date: Fri, 8 Sep 2023 20:52:26 -0500 Subject: [PATCH 3/4] *changed python implementation name from buffer_to_vec_of_vecu64 to buffer_to_felt --- src/python.rs | 4 ++-- tests/python/binding_tests.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/python.rs b/src/python.rs index 420bccea7..897ebdffa 100644 --- a/src/python.rs +++ b/src/python.rs @@ -138,7 +138,7 @@ fn float_to_vecu64(input: f64, scale: u32) -> PyResult<[u64; 4]> { #[pyfunction(signature = ( buffer ))] -fn buffer_to_vec_of_vecu64(buffer: Vec) -> PyResult> { +fn buffer_to_felt(buffer: Vec) -> PyResult> { fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { let mut n: u128 = 0; @@ -787,7 +787,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?; m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?; m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?; - m.add_function(wrap_pyfunction!(buffer_to_vec_of_vecu64, m)?)?; + m.add_function(wrap_pyfunction!(buffer_to_felt, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?; m.add_function(wrap_pyfunction!(table, m)?)?; diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index a7b0be3b8..182e15f56 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -69,17 +69,17 @@ def test_field_serialization(): roundtrip_input = ezkl.vecu64_to_float(felt, scale) assert input == roundtrip_input -def test_buffer_to_vec_of_vecu64(): +def test_buffer_to_felt(): """ - Test buffer_to_vec_of_vecu64 + Test buffer_to_felt """ buffer = bytearray("a sample string!", 'utf-8') - felts = ezkl.buffer_to_vec_of_vecu64(buffer) + felts = ezkl.buffer_to_felt(buffer) ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061" assert felts == [ref_felt_1] buffer = bytearray("a sample string!"+"high", 'utf-8') - felts = ezkl.buffer_to_vec_of_vecu64(buffer) + felts = ezkl.buffer_to_felt(buffer) ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968" assert felts == [ref_felt_1,ref_felt_2] From 3c21fa6f87444ffcd650243c5eb8afa26ee52e77 Mon Sep 17 00:00:00 2001 From: Ethan Cemer Date: Fri, 8 Sep 2023 20:54:20 -0500 Subject: [PATCH 4/4] *update name from felt to felts --- src/python.rs | 4 ++-- tests/python/binding_tests.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/python.rs b/src/python.rs index 897ebdffa..3be1d202e 100644 --- a/src/python.rs +++ b/src/python.rs @@ -138,7 +138,7 @@ fn float_to_vecu64(input: f64, scale: u32) -> PyResult<[u64; 4]> { #[pyfunction(signature = ( buffer ))] -fn buffer_to_felt(buffer: Vec) -> PyResult> { +fn buffer_to_felts(buffer: Vec) -> PyResult> { fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { let mut n: u128 = 0; @@ -787,7 +787,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(vecu64_to_int, m)?)?; m.add_function(wrap_pyfunction!(vecu64_to_float, m)?)?; m.add_function(wrap_pyfunction!(float_to_vecu64, m)?)?; - m.add_function(wrap_pyfunction!(buffer_to_felt, m)?)?; + m.add_function(wrap_pyfunction!(buffer_to_felts, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_aggr, m)?)?; m.add_function(wrap_pyfunction!(gen_vk_from_pk_single, m)?)?; m.add_function(wrap_pyfunction!(table, m)?)?; diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index 182e15f56..9e5489fe5 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -69,17 +69,17 @@ def test_field_serialization(): roundtrip_input = ezkl.vecu64_to_float(felt, scale) assert input == roundtrip_input -def test_buffer_to_felt(): +def test_buffer_to_felts(): """ Test buffer_to_felt """ buffer = bytearray("a sample string!", 'utf-8') - felts = ezkl.buffer_to_felt(buffer) + felts = ezkl.buffer_to_felts(buffer) ref_felt_1 = "0x0000000000000000000000000000000021676e6972747320656c706d61732061" assert felts == [ref_felt_1] buffer = bytearray("a sample string!"+"high", 'utf-8') - felts = ezkl.buffer_to_felt(buffer) + felts = ezkl.buffer_to_felts(buffer) ref_felt_2 = "0x0000000000000000000000000000000000000000000000000000000068676968" assert felts == [ref_felt_1,ref_felt_2]