diff --git a/examples/notebooks/felt_conversion_test.ipynb b/examples/notebooks/felt_conversion_test.ipynb index e3bd7f6bb..fd00ecf4e 100644 --- a/examples/notebooks/felt_conversion_test.ipynb +++ b/examples/notebooks/felt_conversion_test.ipynb @@ -2,9 +2,25 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'inputs': [['a5c7080000000000000000000000000000000000000000000000000000000000', 'b09c1c0000000000000000000000000000000000000000000000000000000000', '29fe2e0000000000000000000000000000000000000000000000000000000000', '5d7e1a0000000000000000000000000000000000000000000000000000000000', 'f3ed390000000000000000000000000000000000000000000000000000000000', '93bf370000000000000000000000000000000000000000000000000000000000', '5973130000000000000000000000000000000000000000000000000000000000', 'f760370000000000000000000000000000000000000000000000000000000000', 'f79b1b0000000000000000000000000000000000000000000000000000000000', '2dee360000000000000000000000000000000000000000000000000000000000', 'f062370000000000000000000000000000000000000000000000000000000000', '5392270000000000000000000000000000000000000000000000000000000000', '2e64270000000000000000000000000000000000000000000000000000000000', 'd2ee1f0000000000000000000000000000000000000000000000000000000000', '1c194f0000000000000000000000000000000000000000000000000000000000', '06452b0000000000000000000000000000000000000000000000000000000000', 'e777330000000000000000000000000000000000000000000000000000000000', 'e58b3a0000000000000000000000000000000000000000000000000000000000', 'c7132e0000000000000000000000000000000000000000000000000000000000', '299e410000000000000000000000000000000000000000000000000000000000', '29813a0000000000000000000000000000000000000000000000000000000000', 'f7c04d0000000000000000000000000000000000000000000000000000000000', 'f9f7360000000000000000000000000000000000000000000000000000000000', '1a5a320000000000000000000000000000000000000000000000000000000000', '710e330000000000000000000000000000000000000000000000000000000000', 'ae1e4d0000000000000000000000000000000000000000000000000000000000', '8859140000000000000000000000000000000000000000000000000000000000', 'f3cc4e0000000000000000000000000000000000000000000000000000000000', '13970c0000000000000000000000000000000000000000000000000000000000', 'df8c240000000000000000000000000000000000000000000000000000000000', 'adee340000000000000000000000000000000000000000000000000000000000', '25b1330000000000000000000000000000000000000000000000000000000000', '43dd300000000000000000000000000000000000000000000000000000000000', 'fa2f2e0000000000000000000000000000000000000000000000000000000000', '68793d0000000000000000000000000000000000000000000000000000000000', '103b080000000000000000000000000000000000000000000000000000000000', 'fcca350000000000000000000000000000000000000000000000000000000000', '1b14370000000000000000000000000000000000000000000000000000000000', '75a5390000000000000000000000000000000000000000000000000000000000', 'a703150000000000000000000000000000000000000000000000000000000000', '5bb22e0000000000000000000000000000000000000000000000000000000000', 'e01e170000000000000000000000000000000000000000000000000000000000', '34f2400000000000000000000000000000000000000000000000000000000000', 'e68f4d0000000000000000000000000000000000000000000000000000000000', 'c230170000000000000000000000000000000000000000000000000000000000', '79392c0000000000000000000000000000000000000000000000000000000000', '6d16100000000000000000000000000000000000000000000000000000000000', 'd9c71d0000000000000000000000000000000000000000000000000000000000', 'ed5a460000000000000000000000000000000000000000000000000000000000', '150d190000000000000000000000000000000000000000000000000000000000', 'fbea330000000000000000000000000000000000000000000000000000000000', '051c1f0000000000000000000000000000000000000000000000000000000000', 'c2c30e0000000000000000000000000000000000000000000000000000000000', '767e0a0000000000000000000000000000000000000000000000000000000000', '38d80c0000000000000000000000000000000000000000000000000000000000', 'c6b5390000000000000000000000000000000000000000000000000000000000', 'd4e63f0000000000000000000000000000000000000000000000000000000000', 'bd1c150000000000000000000000000000000000000000000000000000000000', 'fa02490000000000000000000000000000000000000000000000000000000000', '85aa080000000000000000000000000000000000000000000000000000000000', '844e190000000000000000000000000000000000000000000000000000000000', '8439330000000000000000000000000000000000000000000000000000000000', '3aba270000000000000000000000000000000000000000000000000000000000', '67ae140000000000000000000000000000000000000000000000000000000000', 'eb23240000000000000000000000000000000000000000000000000000000000', 'a8980b0000000000000000000000000000000000000000000000000000000000', '85893f0000000000000000000000000000000000000000000000000000000000', '2129410000000000000000000000000000000000000000000000000000000000', '5312220000000000000000000000000000000000000000000000000000000000', '96e4250000000000000000000000000000000000000000000000000000000000', '95b01b0000000000000000000000000000000000000000000000000000000000', '1ac2120000000000000000000000000000000000000000000000000000000000', 'e4da250000000000000000000000000000000000000000000000000000000000', '18ef4b0000000000000000000000000000000000000000000000000000000000', '050f260000000000000000000000000000000000000000000000000000000000', '1dbb340000000000000000000000000000000000000000000000000000000000', '9294090000000000000000000000000000000000000000000000000000000000', '1394190000000000000000000000000000000000000000000000000000000000', 'e04a4e0000000000000000000000000000000000000000000000000000000000', '1835450000000000000000000000000000000000000000000000000000000000', '44a84c0000000000000000000000000000000000000000000000000000000000', '922b200000000000000000000000000000000000000000000000000000000000', '91d03d0000000000000000000000000000000000000000000000000000000000', 'db3e490000000000000000000000000000000000000000000000000000000000', 'c151480000000000000000000000000000000000000000000000000000000000', '74090e0000000000000000000000000000000000000000000000000000000000', '9f33170000000000000000000000000000000000000000000000000000000000', 'f3652b0000000000000000000000000000000000000000000000000000000000', 'ab2e230000000000000000000000000000000000000000000000000000000000', '94f3130000000000000000000000000000000000000000000000000000000000', '41173a0000000000000000000000000000000000000000000000000000000000', '7917340000000000000000000000000000000000000000000000000000000000', '26aa2d0000000000000000000000000000000000000000000000000000000000', 'b8ef400000000000000000000000000000000000000000000000000000000000', 'f731410000000000000000000000000000000000000000000000000000000000', '22ef2a0000000000000000000000000000000000000000000000000000000000', '71b12d0000000000000000000000000000000000000000000000000000000000', '9fa4380000000000000000000000000000000000000000000000000000000000', 'f3f2420000000000000000000000000000000000000000000000000000000000', 'af35330000000000000000000000000000000000000000000000000000000000']], 'outputs': [['a5c7080000000000000000000000000000000000000000000000000000000000', 'b09c1c0000000000000000000000000000000000000000000000000000000000', '29fe2e0000000000000000000000000000000000000000000000000000000000', '5d7e1a0000000000000000000000000000000000000000000000000000000000', 'f3ed390000000000000000000000000000000000000000000000000000000000', '93bf370000000000000000000000000000000000000000000000000000000000', '5973130000000000000000000000000000000000000000000000000000000000', 'f760370000000000000000000000000000000000000000000000000000000000', 'f79b1b0000000000000000000000000000000000000000000000000000000000', '2dee360000000000000000000000000000000000000000000000000000000000', 'f062370000000000000000000000000000000000000000000000000000000000', '5392270000000000000000000000000000000000000000000000000000000000', '2e64270000000000000000000000000000000000000000000000000000000000', 'd2ee1f0000000000000000000000000000000000000000000000000000000000', '1c194f0000000000000000000000000000000000000000000000000000000000', '06452b0000000000000000000000000000000000000000000000000000000000', 'e777330000000000000000000000000000000000000000000000000000000000', 'e58b3a0000000000000000000000000000000000000000000000000000000000', 'c7132e0000000000000000000000000000000000000000000000000000000000', '299e410000000000000000000000000000000000000000000000000000000000', '29813a0000000000000000000000000000000000000000000000000000000000', 'f7c04d0000000000000000000000000000000000000000000000000000000000', 'f9f7360000000000000000000000000000000000000000000000000000000000', '1a5a320000000000000000000000000000000000000000000000000000000000', '710e330000000000000000000000000000000000000000000000000000000000', 'ae1e4d0000000000000000000000000000000000000000000000000000000000', '8859140000000000000000000000000000000000000000000000000000000000', 'f3cc4e0000000000000000000000000000000000000000000000000000000000', '13970c0000000000000000000000000000000000000000000000000000000000', 'df8c240000000000000000000000000000000000000000000000000000000000', 'adee340000000000000000000000000000000000000000000000000000000000', '25b1330000000000000000000000000000000000000000000000000000000000', '43dd300000000000000000000000000000000000000000000000000000000000', 'fa2f2e0000000000000000000000000000000000000000000000000000000000', '68793d0000000000000000000000000000000000000000000000000000000000', '103b080000000000000000000000000000000000000000000000000000000000', 'fcca350000000000000000000000000000000000000000000000000000000000', '1b14370000000000000000000000000000000000000000000000000000000000', '75a5390000000000000000000000000000000000000000000000000000000000', 'a703150000000000000000000000000000000000000000000000000000000000', '5bb22e0000000000000000000000000000000000000000000000000000000000', 'e01e170000000000000000000000000000000000000000000000000000000000', '34f2400000000000000000000000000000000000000000000000000000000000', 'e68f4d0000000000000000000000000000000000000000000000000000000000', 'c230170000000000000000000000000000000000000000000000000000000000', '79392c0000000000000000000000000000000000000000000000000000000000', '6d16100000000000000000000000000000000000000000000000000000000000', 'd9c71d0000000000000000000000000000000000000000000000000000000000', 'ed5a460000000000000000000000000000000000000000000000000000000000', '150d190000000000000000000000000000000000000000000000000000000000', 'fbea330000000000000000000000000000000000000000000000000000000000', '051c1f0000000000000000000000000000000000000000000000000000000000', 'c2c30e0000000000000000000000000000000000000000000000000000000000', '767e0a0000000000000000000000000000000000000000000000000000000000', '38d80c0000000000000000000000000000000000000000000000000000000000', 'c6b5390000000000000000000000000000000000000000000000000000000000', 'd4e63f0000000000000000000000000000000000000000000000000000000000', 'bd1c150000000000000000000000000000000000000000000000000000000000', 'fa02490000000000000000000000000000000000000000000000000000000000', '85aa080000000000000000000000000000000000000000000000000000000000', '844e190000000000000000000000000000000000000000000000000000000000', '8439330000000000000000000000000000000000000000000000000000000000', '3aba270000000000000000000000000000000000000000000000000000000000', '67ae140000000000000000000000000000000000000000000000000000000000', 'eb23240000000000000000000000000000000000000000000000000000000000', 'a8980b0000000000000000000000000000000000000000000000000000000000', '85893f0000000000000000000000000000000000000000000000000000000000', '2129410000000000000000000000000000000000000000000000000000000000', '5312220000000000000000000000000000000000000000000000000000000000', '96e4250000000000000000000000000000000000000000000000000000000000', '95b01b0000000000000000000000000000000000000000000000000000000000', '1ac2120000000000000000000000000000000000000000000000000000000000', 'e4da250000000000000000000000000000000000000000000000000000000000', '18ef4b0000000000000000000000000000000000000000000000000000000000', '050f260000000000000000000000000000000000000000000000000000000000', '1dbb340000000000000000000000000000000000000000000000000000000000', '9294090000000000000000000000000000000000000000000000000000000000', '1394190000000000000000000000000000000000000000000000000000000000', 'e04a4e0000000000000000000000000000000000000000000000000000000000', '1835450000000000000000000000000000000000000000000000000000000000', '44a84c0000000000000000000000000000000000000000000000000000000000', '922b200000000000000000000000000000000000000000000000000000000000', '91d03d0000000000000000000000000000000000000000000000000000000000', 'db3e490000000000000000000000000000000000000000000000000000000000', 'c151480000000000000000000000000000000000000000000000000000000000', '74090e0000000000000000000000000000000000000000000000000000000000', '9f33170000000000000000000000000000000000000000000000000000000000', 'f3652b0000000000000000000000000000000000000000000000000000000000', 'ab2e230000000000000000000000000000000000000000000000000000000000', '94f3130000000000000000000000000000000000000000000000000000000000', '41173a0000000000000000000000000000000000000000000000000000000000', '7917340000000000000000000000000000000000000000000000000000000000', '26aa2d0000000000000000000000000000000000000000000000000000000000', 'b8ef400000000000000000000000000000000000000000000000000000000000', 'f731410000000000000000000000000000000000000000000000000000000000', '22ef2a0000000000000000000000000000000000000000000000000000000000', '71b12d0000000000000000000000000000000000000000000000000000000000', '9fa4380000000000000000000000000000000000000000000000000000000000', 'f3f2420000000000000000000000000000000000000000000000000000000000', 'af35330000000000000000000000000000000000000000000000000000000000']], 'max_lookup_inputs': 0, 'min_lookup_inputs': 0, 'max_range_size': 0}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dante/Documents/GitHub/ezkl/.env/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "\n", "\n", @@ -16,13 +32,13 @@ "\n", "\n", "class Passthrough(torch.nn.Module):\n", - " def __init__(self, input_size=10):\n", + " def __init__(self, input_size=100):\n", " super().__init__()\n", "\n", " def forward(self, x):\n", " return x\n", "\n", - "def generate_random_data(size=10, min_val=1, max_val=10):\n", + "def generate_random_data(size=100, min_val=1, max_val=10):\n", " return [min_val + (max_val - min_val) * torch.rand(1).item() for _ in range(size)]\n", "\n", "def save_json(data, filename):\n", @@ -38,8 +54,10 @@ " gip_run_args.param_scale = 19\n", " gip_run_args.logrows = 8\n", " run_args = ezkl.gen_settings(py_run_args=gip_run_args)\n", + " ezkl.get_srs(commitment=ezkl.PyCommitments.KZG)\n", " ezkl.compile_circuit()\n", - " await ezkl.gen_witness()\n", + " res = await ezkl.gen_witness()\n", + " print(res)\n", " ezkl.setup()\n", " ezkl.prove(proof_path=\"proof.json\")\n", " ezkl.verify()\n", @@ -53,7 +71,7 @@ " model_shapes = settings[\"model_instance_shapes\"]\n", "\n", " flat_inputs = [x for arr in inputs[\"input_data\"] for x in arr]\n", - " scaled_inputs = [ezkl.float_to_felt(x, input_scale) for x in flat_inputs]\n", + " scaled_inputs = [ezkl.float_to_felt(x, input_scale, ezkl.PyInputType.F32) for x in flat_inputs]\n", " proof_instances = proof[\"instances\"][0]\n", "\n", " def get_group_index(i):\n", @@ -76,7 +94,7 @@ " assert scaled == instance, f\"Input mismatch at index {i}: {scaled} != {instance} ({descaled_instance} != {descaled_input} OG {flat_inputs[i]} PRETTY {pretty_value})\"\n", "\n", "model = Passthrough()\n", - "torch.onnx.export(model, torch.randn(1, 10), \"network.onnx\")\n", + "torch.onnx.export(model, torch.randn(1, 100), \"network.onnx\")\n", "\n", "input_data = {\"input_data\": [generate_random_data()]}\n", "save_json(input_data, \"input.json\")\n", diff --git a/src/bindings/python.rs b/src/bindings/python.rs index c58927da5..14c840be0 100644 --- a/src/bindings/python.rs +++ b/src/bindings/python.rs @@ -4,6 +4,7 @@ use crate::circuit::modules::poseidon::{ PoseidonChip, }; use crate::circuit::modules::Module; +use crate::circuit::InputType; use crate::circuit::{CheckMode, Tolerance}; use crate::commands::*; use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep}; @@ -316,6 +317,65 @@ impl FromStr for PyCommitments { } } +#[pyclass] +#[derive(Debug, Clone)] +#[gen_stub_pyclass_enum] +enum PyInputType { + /// + Bool, + /// + F16, + /// + F32, + /// + F64, + /// + Int, + /// + TDim, +} + +impl From for PyInputType { + fn from(input_type: InputType) -> Self { + match input_type { + InputType::Bool => PyInputType::Bool, + InputType::F16 => PyInputType::F16, + InputType::F32 => PyInputType::F32, + InputType::F64 => PyInputType::F64, + InputType::Int => PyInputType::Int, + InputType::TDim => PyInputType::TDim, + } + } +} + +impl From for InputType { + fn from(py_input_type: PyInputType) -> Self { + match py_input_type { + PyInputType::Bool => InputType::Bool, + PyInputType::F16 => InputType::F16, + PyInputType::F32 => InputType::F32, + PyInputType::F64 => InputType::F64, + PyInputType::Int => InputType::Int, + PyInputType::TDim => InputType::TDim, + } + } +} + +impl FromStr for PyInputType { + type Err = String; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "bool" => Ok(PyInputType::Bool), + "f16" => Ok(PyInputType::F16), + "f32" => Ok(PyInputType::F32), + "f64" => Ok(PyInputType::F64), + "int" => Ok(PyInputType::Int), + "tdim" => Ok(PyInputType::TDim), + _ => Err("Invalid value for InputType".to_string()), + } + } +} + /// Converts a field element hex string to big endian /// /// Arguments @@ -396,6 +456,9 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult { /// scale: float /// The scaling factor used to quantize the float into a field element /// +/// input_type: PyInputType +/// The type of the input +/// /// Returns /// ------- /// str @@ -403,10 +466,12 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult { /// #[pyfunction(signature = ( input, - scale + scale, + input_type=PyInputType::F64 ))] #[gen_stub_pyfunction] -fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult { +fn float_to_felt(mut input: f64, scale: crate::Scale, input_type: PyInputType) -> PyResult { + InputType::roundtrip(&input_type.into(), &mut input); let int_rep = quantize_float(&input, 0.0, scale) .map_err(|_| PyIOError::new_err("Failed to quantize input"))?; let felt = integer_rep_to_felt(int_rep); @@ -1968,6 +2033,7 @@ fn ezkl(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?; m.add_function(wrap_pyfunction!(felt_to_int, m)?)?; diff --git a/src/bindings/wasm.rs b/src/bindings/wasm.rs index 3adc41843..bcec19837 100644 --- a/src/bindings/wasm.rs +++ b/src/bindings/wasm.rs @@ -22,6 +22,7 @@ use halo2curves::{ bn256::{Bn256, Fr, G1Affine}, ff::PrimeField, }; +use std::str::FromStr; use wasm_bindgen::prelude::*; use wasm_bindgen_console_logger::DEFAULT_LOGGER; @@ -113,9 +114,15 @@ pub fn feltToFloat( #[wasm_bindgen] #[allow(non_snake_case)] pub fn floatToFelt( - input: f64, + mut input: f64, scale: crate::Scale, + input_type: &str, ) -> Result>, JsError> { + crate::circuit::InputType::roundtrip( + &crate::circuit::InputType::from_str(input_type) + .map_err(|e| JsError::new(&format!("{}", e)))?, + &mut input, + ); let int_rep = quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?; let felt = integer_rep_to_felt(int_rep); diff --git a/src/circuit/ops/errors.rs b/src/circuit/ops/errors.rs index cb488ca6b..97e6da8f3 100644 --- a/src/circuit/ops/errors.rs +++ b/src/circuit/ops/errors.rs @@ -97,4 +97,7 @@ pub enum CircuitError { /// Invalid scale #[error("negative scale for an op that requires positive inputs {0}")] NegativeScale(String), + #[error("invalid input type {0}")] + /// Invalid input type + InvalidInputType(String), } diff --git a/src/circuit/ops/mod.rs b/src/circuit/ops/mod.rs index 552a782fc..1ac8ba91c 100644 --- a/src/circuit/ops/mod.rs +++ b/src/circuit/ops/mod.rs @@ -105,7 +105,10 @@ impl InputType { } /// - pub fn roundtrip(&self, input: &mut T) { + pub fn roundtrip( + &self, + input: &mut T, + ) { match self { InputType::Bool => { let boolean_input = input.clone().to_i64().unwrap(); @@ -118,7 +121,7 @@ impl InputType { *input = T::from_f32(f32_input).unwrap(); } InputType::F32 => { - let f32_input = input.clone().to_f32().unwrap(); + let f32_input: f32 = input.clone().to_f32().unwrap(); *input = T::from_f32(f32_input).unwrap(); } InputType::F64 => { @@ -133,6 +136,22 @@ impl InputType { } } +impl std::str::FromStr for InputType { + type Err = CircuitError; + + fn from_str(s: &str) -> Result { + match s { + "bool" => Ok(InputType::Bool), + "f16" => Ok(InputType::F16), + "f32" => Ok(InputType::F32), + "f64" => Ok(InputType::F64), + "int" => Ok(InputType::Int), + "tdim" => Ok(InputType::TDim), + e => Err(CircuitError::InvalidInputType(e.to_string())), + } + } +} + /// #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct Input {