diff --git a/Cargo.lock b/Cargo.lock index 0495b9418..d63d7a8c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -655,6 +655,12 @@ dependencies = [ "constant_time_eq", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.9.0" @@ -1011,6 +1017,17 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -1793,8 +1810,10 @@ dependencies = [ "lazy_static", "log", "maybe-rayon", + "metal", "mnist", "num", + "objc", "openssl", "pg_bigdecimal", "plotters", @@ -1932,7 +1951,28 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared", + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.53", ] [[package]] @@ -1941,6 +1981,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2933,6 +2979,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2993,6 +3048,20 @@ dependencies = [ "autocfg", ] +[[package]] +name = "metal" +version = "0.27.0" +source = "git+https://github.com/gfx-rs/metal-rs#ff8fd3d6dc7792852f8a015458d7e6d42d7fb352" +dependencies = [ + "bitflags 2.5.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "mime" version = "0.3.17" @@ -3205,6 +3274,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "object" version = "0.32.2" @@ -3265,7 +3343,7 @@ checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.5.0", "cfg-if", - "foreign-types", + "foreign-types 0.3.2", "libc", "once_cell", "openssl-macros", diff --git a/Cargo.toml b/Cargo.toml index a72adf788..d16dfb879 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ unzip-n = "0.1.2" num = "0.4.1" portable-atomic = "1.6.0" tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" } +metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true } # evm related deps @@ -83,6 +84,8 @@ pyo3-log = { version = "0.9.0", default_features = false, optional = true } tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true } tabled = { version = "0.12.0", optional = true } +objc = { version = "0.2.4", optional = true } + [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies] colored = { version = "2.0.0", default_features = false, optional = true } @@ -198,6 +201,7 @@ det-prove = [] icicle = ["halo2_proofs/icicle_gpu"] empty-cmd = [] no-banner = [] +metal = ["dep:metal", "dep:objc"] # icicle patch to 0.1.0 if feature icicle is enabled [patch.'https://github.com/ingonyama-zk/icicle'] diff --git a/abis/QuantizeData.json b/abis/QuantizeData.json index 25e58a475..b43a96f9e 100644 --- a/abis/QuantizeData.json +++ b/abis/QuantizeData.json @@ -20,9 +20,9 @@ "name": "quantize_data", "outputs": [ { - "internalType": "int128[]", + "internalType": "int64[]", "name": "quantized_data", - "type": "int128[]" + "type": "int64[]" } ], "stateMutability": "pure", @@ -31,9 +31,9 @@ { "inputs": [ { - "internalType": "int128[]", + "internalType": "int64[]", "name": "quantized_data", - "type": "int128[]" + "type": "int64[]" } ], "name": "to_field_element", diff --git a/contracts/QuantizeData.sol b/contracts/QuantizeData.sol index ce832b42a..babbdf8e9 100644 --- a/contracts/QuantizeData.sol +++ b/contracts/QuantizeData.sol @@ -125,7 +125,7 @@ contract QuantizeData { } function to_field_element( - int128[] memory quantized_data + int64[] memory quantized_data ) public pure returns (uint256[] memory output) { output = new uint256[](quantized_data.length); for (uint i; i < quantized_data.length; i++) { diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index bd1dd8f85..866773b47 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -42,8 +42,8 @@ const NUM_INNER_COLS: usize = 1; struct Config< const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable const CLASSES: usize, - const LOOKUP_MIN: i128, - const LOOKUP_MAX: i128, + const LOOKUP_MIN: i64, + const LOOKUP_MAX: i64, // Convolution const KERNEL_HEIGHT: usize, const KERNEL_WIDTH: usize, @@ -66,8 +66,8 @@ struct Config< struct MyCircuit< const LEN: usize, //LEN = CHOUT x OH x OW flattened const CLASSES: usize, - const LOOKUP_MIN: i128, - const LOOKUP_MAX: i128, + const LOOKUP_MIN: i64, + const LOOKUP_MAX: i64, // Convolution const KERNEL_HEIGHT: usize, const KERNEL_WIDTH: usize, @@ -90,8 +90,8 @@ struct MyCircuit< impl< const LEN: usize, const CLASSES: usize, - const LOOKUP_MIN: i128, - const LOOKUP_MAX: i128, + const LOOKUP_MIN: i64, + const LOOKUP_MAX: i64, // Convolution const KERNEL_HEIGHT: usize, const KERNEL_WIDTH: usize, diff --git a/examples/mlp_4d_einsum.rs b/examples/mlp_4d_einsum.rs index 55a55a2f9..bec3ecb7a 100644 --- a/examples/mlp_4d_einsum.rs +++ b/examples/mlp_4d_einsum.rs @@ -23,8 +23,8 @@ struct MyConfig { #[derive(Clone)] struct MyCircuit< const LEN: usize, //LEN = CHOUT x OH x OW flattened - const LOOKUP_MIN: i128, - const LOOKUP_MAX: i128, + const LOOKUP_MIN: i64, + const LOOKUP_MAX: i64, > { // Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers. // Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter. @@ -34,7 +34,7 @@ struct MyCircuit< _marker: PhantomData, } -impl Circuit +impl Circuit for MyCircuit { type Config = MyConfig; diff --git a/src/circuit/ops/chip.rs b/src/circuit/ops/chip.rs index f53aaddbc..cfd6a06bd 100644 --- a/src/circuit/ops/chip.rs +++ b/src/circuit/ops/chip.rs @@ -24,7 +24,7 @@ use crate::{ table::{Range, RangeCheck, Table}, utils, }, - tensor::{Tensor, TensorType, ValTensor, VarTensor}, + tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor}, }; use std::{collections::BTreeMap, error::Error, marker::PhantomData}; @@ -345,7 +345,7 @@ pub struct BaseConfig { _marker: PhantomData, } -impl BaseConfig { +impl BaseConfig { /// Returns a new [BaseConfig] with no inputs, no selectors, and no tables. pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self { Self { diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 0dd89dd66..463723319 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -1,7 +1,7 @@ use super::*; use crate::{ circuit::{layouts, utils, Tolerance}, - fieldutils::i128_to_felt, + fieldutils::i64_to_felt, graph::multiplier_to_scale, tensor::{self, Tensor, TensorType, ValTensor}, }; @@ -71,7 +71,7 @@ pub enum HybridOp { }, } -impl Op for HybridOp { +impl Op for HybridOp { /// fn requires_homogenous_input_scales(&self) -> Vec { match self { @@ -184,8 +184,8 @@ impl Op for Hybrid config, region, values[..].try_into()?, - i128_to_felt(input_scale.0 as i128), - i128_to_felt(output_scale.0 as i128), + i64_to_felt(input_scale.0 as i64), + i64_to_felt(output_scale.0 as i64), )? } else { layouts::nonlinearity( @@ -209,7 +209,7 @@ impl Op for Hybrid config, region, values[..].try_into()?, - i128_to_felt(denom.0 as i128), + i64_to_felt(denom.0 as i64), )? } else { layouts::nonlinearity( diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 93581e2fe..1fc33842b 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -14,7 +14,7 @@ use maybe_rayon::{ slice::ParallelSliceMut, }; -use self::tensor::{create_constant_tensor, create_zero_tensor}; +use self::tensor::{create_constant_tensor, create_zero_tensor, IntoI64}; use super::{ chip::{BaseConfig, CircuitError}, @@ -22,7 +22,7 @@ use super::{ }; use crate::{ circuit::{ops::base::BaseOp, utils}, - fieldutils::{felt_to_i128, i128_to_felt}, + fieldutils::{felt_to_i64, i64_to_felt}, tensor::{ create_unit_tensor, get_broadcasted_shape, ops::{accumulated, add, mult, sub}, @@ -34,7 +34,7 @@ use super::*; use crate::circuit::ops::lookup::LookupOp; /// Same as div but splits the division into N parts -pub(crate) fn loop_div( +pub(crate) fn loop_div( config: &BaseConfig, region: &mut RegionCtx, value: &[ValTensor; 1], @@ -48,8 +48,8 @@ pub(crate) fn loop_div (2_i128.pow(F::S - 4)) { - divisor = i128_to_felt(felt_to_i128(divisor) / 2); + while felt_to_i64(divisor) % 2 == 0 && felt_to_i64(divisor) > (2_i64.pow(F::S - 4)) { + divisor = i64_to_felt(felt_to_i64(divisor) / 2); num_parts += 1; } @@ -58,9 +58,9 @@ pub(crate) fn loop_div( +pub(crate) fn div( config: &BaseConfig, region: &mut RegionCtx, value: &[ValTensor; 1], @@ -82,7 +82,7 @@ pub(crate) fn div( let input = value[0].clone(); let input_dims = input.dims(); - let range_check_bracket = felt_to_i128(div) / 2; + let range_check_bracket = felt_to_i64(div) / 2; let divisor = create_constant_tensor(div, 1); @@ -93,9 +93,9 @@ pub(crate) fn div( let mut claimed_output: ValTensor = if is_assigned { let input_evals = input.get_int_evals()?; - tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64) + tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i64(div) as f64) .par_iter() - .map(|x| Value::known(i128_to_felt(*x))) + .map(|x| Value::known(i64_to_felt(*x))) .collect::>>() .into() } else { @@ -134,7 +134,7 @@ pub(crate) fn div( } /// recip accumulated layout -pub(crate) fn recip( +pub(crate) fn recip( config: &BaseConfig, region: &mut RegionCtx, value: &[ValTensor; 1], @@ -144,14 +144,14 @@ pub(crate) fn recip( let input = value[0].clone(); let input_dims = input.dims(); - let integer_input_scale = felt_to_i128(input_scale); - let integer_output_scale = felt_to_i128(output_scale); + let integer_input_scale = felt_to_i64(input_scale); + let integer_output_scale = felt_to_i64(output_scale); // range_check_bracket is min of input_scale * output_scale and 2^F::S - 3 - let range_check_len = std::cmp::min(integer_output_scale, 2_i128.pow(F::S - 4)); + let range_check_len = std::cmp::min(integer_output_scale, 2_i64.pow(F::S - 4)); let input_scale_ratio = if range_check_len > 0 { - i128_to_felt(integer_input_scale * integer_output_scale / range_check_len) + i64_to_felt(integer_input_scale * integer_output_scale / range_check_len) } else { F::ONE }; @@ -164,11 +164,11 @@ pub(crate) fn recip( let input_evals = input.get_int_evals()?; tensor::ops::nonlinearities::recip( &input_evals, - felt_to_i128(input_scale) as f64, - felt_to_i128(output_scale) as f64, + felt_to_i64(input_scale) as f64, + felt_to_i64(output_scale) as f64, ) .par_iter() - .map(|x| Value::known(i128_to_felt(*x))) + .map(|x| Value::known(i64_to_felt(*x))) .collect::>>() .into() } else { @@ -194,8 +194,8 @@ pub(crate) fn recip( let rebased_div = loop_div(config, region, &[product], input_scale_ratio)?; let zero_inverse_val = - tensor::ops::nonlinearities::zero_recip(felt_to_i128(output_scale) as f64)[0]; - let zero_inverse = create_constant_tensor(i128_to_felt(zero_inverse_val), 1); + tensor::ops::nonlinearities::zero_recip(felt_to_i64(output_scale) as f64)[0]; + let zero_inverse = create_constant_tensor(i64_to_felt(zero_inverse_val), 1); let equal_zero_mask = equals_zero(config, region, &[input.clone()])?; @@ -208,7 +208,7 @@ pub(crate) fn recip( &[equal_zero_mask.clone(), equal_inverse_mask], )?; - let unit_scale = create_constant_tensor(i128_to_felt(range_check_len), 1); + let unit_scale = create_constant_tensor(i64_to_felt(range_check_len), 1); let unit_mask = pairwise(config, region, &[equal_zero_mask, unit_scale], BaseOp::Mult)?; @@ -237,19 +237,19 @@ pub(crate) fn recip( /// use ezkl::circuit::layouts::dot; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 3, 3], /// ).unwrap()); -/// let y = ValTensor::from_i128_tensor(Tensor::::new( +/// let y = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 5, 10, -4, 2, -1, 2, 0, 1]), /// &[1, 3, 3], /// ).unwrap()); /// assert_eq!(dot::(&dummy_config, &mut dummy_region, &[x, y]).unwrap().get_int_evals().unwrap()[0], 86); /// ``` -pub fn dot( +pub fn dot( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -369,172 +369,172 @@ pub fn dot( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// /// // matmul case -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 3, 2, 1, 1, 1]), /// &[3, 2], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "ij,jk->ik").unwrap(); -/// let expected = Tensor::::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[8, 9, 5, 5]), &[2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // element wise multiplication -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), /// &[3, 3], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "ij,ij->ij").unwrap(); -/// let expected = Tensor::::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 4, 9, 2, 6, 12, 3, 8, 15]), &[3, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// /// // dot product of A with the transpose of B. -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), /// &[3, 3], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "ik,jk->ij").unwrap(); -/// let expected = Tensor::::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[14, 14, 14, 20, 20, 20, 26, 26, 26]), &[3, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // dot product -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 1, 2, 3, 1, 2, 3]), /// &[3, 3], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "ik,ik->i").unwrap(); -/// let expected = Tensor::::new(Some(&[14, 20, 26]), &[3]).unwrap(); +/// let expected = Tensor::::new(Some(&[14, 20, 26]), &[3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// /// // dot product -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3]), /// &[3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3]), /// &[3], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "i,i->").unwrap(); -/// let expected = Tensor::::new(Some(&[14]), &[1]).unwrap(); +/// let expected = Tensor::::new(Some(&[14]), &[1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// /// // wut ? -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8]), /// &[2, 2], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "anm,bm->ba").unwrap(); -/// let expected = Tensor::::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[68, 80, 95, 113, 134, 158]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // wutttttt ? -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8]), /// &[2, 2], /// ).unwrap()); -/// let z = ValTensor::from_i128_tensor(Tensor::::new( +/// let z = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8, 9, 9]), /// &[2, 3], /// ).unwrap()); /// /// let result = einsum::(&dummy_config, &mut dummy_region, &[z, x, k], "bn,anm,bm->ba").unwrap(); -/// let expected = Tensor::::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[390, 414, 534, 994, 1153, 1384]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// /// // contraction with a single common axis -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8]), /// &[2, 2], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "abc,cd->").unwrap(); -/// let expected = Tensor::::new(Some(&[648]), &[1]).unwrap(); +/// let expected = Tensor::::new(Some(&[648]), &[1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // contraction with no common axes (outer product) -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 2, 3, 4, 3, 4, 5, 1, 2, 3, 2, 3, 4, 3, 4, 5]), /// &[3, 3, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8]), /// &[2, 2], /// ).unwrap()); /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "abc,ed->").unwrap(); -/// let expected = Tensor::::new(Some(&[1296]), &[1]).unwrap(); +/// let expected = Tensor::::new(Some(&[1296]), &[1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // trivial axes mapping -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5, 7, 8]), /// &[2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[4, 5]), /// &[2], /// ).unwrap()); /// /// let result = einsum::(&dummy_config, &mut dummy_region, &[x.clone(), k.clone()], "mk,k->m").unwrap(); -/// let expected = Tensor::::new(Some(&[41, 68]), &[2]).unwrap(); +/// let expected = Tensor::::new(Some(&[41, 68]), &[2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// let result = einsum::(&dummy_config, &mut dummy_region, &[x, k], "mk,k->mn").unwrap(); -/// let expected = Tensor::::new(Some(&[41, 68]), &[2, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[41, 68]), &[2, 1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[0, 0, 0, 3]), /// &[1, 4], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[213, 227, 74, 77]), /// &[4], /// ).unwrap()); /// /// let result = einsum::(&dummy_config, &mut dummy_region, &[x.clone(), k.clone()], "mk,k->ma").unwrap(); -/// let expected = Tensor::::new(Some(&[231]), &[1, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[231]), &[1, 1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// // subtle difference /// let result = einsum::(&dummy_config, &mut dummy_region, &[x.clone(), k.clone()], "mk,n->ma").unwrap(); -/// let expected = Tensor::::new(Some(&[1773]), &[1, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[1773]), &[1, 1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// ``` /// -pub fn einsum( +pub fn einsum( config: &BaseConfig, region: &mut RegionCtx, inputs: &[ValTensor], @@ -753,7 +753,7 @@ pub fn einsum( Ok(output) } -fn _sort_ascending( +fn _sort_ascending( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -768,7 +768,7 @@ fn _sort_ascending( int_evals.par_sort_unstable_by(|a, b| a.cmp(b)); int_evals .par_iter() - .map(|x| Value::known(i128_to_felt(*x))) + .map(|x| Value::known(i64_to_felt(*x))) .collect::>>() } else { Tensor::new( @@ -797,7 +797,7 @@ fn _sort_ascending( } /// Returns top K values. -fn _select_topk( +fn _select_topk( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -822,20 +822,20 @@ fn _select_topk( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2,3], /// ).unwrap()); /// let result = topk_axes::(&dummy_config, &mut dummy_region, &[x], 2, 1, true).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[15, 2, 1, 1]), /// &[2,2], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn topk_axes( +pub fn topk_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -855,7 +855,7 @@ pub fn topk_axes( Ok(output) } -fn select( +fn select( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -894,7 +894,7 @@ fn select( Ok(assigned_output) } -fn one_hot( +fn one_hot( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -911,7 +911,7 @@ fn one_hot( let int_evals = input.get_int_evals()?; let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?; res.par_iter() - .map(|x| Value::known(i128_to_felt(*x))) + .map(|x| Value::known(i64_to_felt(*x))) .collect::>() } else { Tensor::new( @@ -946,7 +946,9 @@ fn one_hot( } /// Dynamic lookup -pub(crate) fn dynamic_lookup( +pub(crate) fn dynamic_lookup< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, lookups: &[ValTensor; 2], @@ -1036,7 +1038,7 @@ pub(crate) fn dynamic_lookup( +pub(crate) fn shuffles( config: &BaseConfig, region: &mut RegionCtx, input: &[ValTensor; 1], @@ -1099,7 +1101,7 @@ pub(crate) fn shuffles( +pub(crate) fn one_hot_axis( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -1152,7 +1154,7 @@ pub(crate) fn one_hot_axis( +pub(crate) fn gather( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -1180,7 +1182,9 @@ pub(crate) fn gather( } /// Gather accumulated layout -pub(crate) fn gather_elements( +pub(crate) fn gather_elements< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -1203,7 +1207,7 @@ pub(crate) fn gather_elements( +pub(crate) fn gather_nd( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -1254,7 +1258,9 @@ pub(crate) fn gather_nd( +pub(crate) fn linearize_element_index< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -1359,7 +1365,9 @@ pub(crate) fn linearize_element_index( +pub(crate) fn linearize_nd_index< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -1487,7 +1495,7 @@ pub(crate) fn linearize_nd_index() as i128), + .all(|x| *x < dims.iter().product::() as i64), "res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})", dims.iter().product::(), index_val.show(), @@ -1509,7 +1517,7 @@ pub(crate) fn linearize_nd_index( config: &BaseConfig, region: &mut RegionCtx, @@ -1542,7 +1550,7 @@ pub(crate) fn get_missing_set_elements< fullset_evals .par_iter() - .map(|x| Value::known(i128_to_felt(*x))) + .map(|x| Value::known(i64_to_felt(*x))) .collect::>>() .into() } else { @@ -1574,7 +1582,9 @@ pub(crate) fn get_missing_set_elements< } /// Gather accumulated layout -pub(crate) fn scatter_elements( +pub(crate) fn scatter_elements< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 3], @@ -1599,7 +1609,7 @@ pub(crate) fn scatter_elements>>() .into() } else { @@ -1656,7 +1666,7 @@ pub(crate) fn scatter_elements( +pub(crate) fn scatter_nd( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 3], @@ -1678,7 +1688,7 @@ pub(crate) fn scatter_nd>>() .into() } else { @@ -1744,9 +1754,9 @@ pub(crate) fn scatter_nd::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); @@ -1754,7 +1764,7 @@ pub(crate) fn scatter_nd( +pub fn sum( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -1851,9 +1861,9 @@ pub fn sum( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); @@ -1861,7 +1871,7 @@ pub fn sum( /// let expected = 0; /// assert_eq!(result.get_int_evals().unwrap()[0], expected); /// ``` -pub fn prod( +pub fn prod( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -1941,7 +1951,7 @@ pub fn prod( } /// Axes wise op wrapper -fn axes_wise_op( +fn axes_wise_op( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2012,20 +2022,20 @@ fn axes_wise_op( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = prod_axes::(&dummy_config, &mut dummy_region, &[x], &[1]).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[60, 0]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn prod_axes( +pub fn prod_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2046,20 +2056,20 @@ pub fn prod_axes( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = sum_axes::(&dummy_config, &mut dummy_region, &[x], &[1]).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[19, 2]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn sum_axes( +pub fn sum_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2080,20 +2090,20 @@ pub fn sum_axes( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = argmax_axes::(&dummy_config, &mut dummy_region, &[x], 1).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[1, 0]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn argmax_axes( +pub fn argmax_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2121,19 +2131,19 @@ pub fn argmax_axes( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = max_axes::(&dummy_config, &mut dummy_region, &[x], &[1]).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[15, 1]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn max_axes( +pub fn max_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2155,20 +2165,20 @@ pub fn max_axes( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = argmin_axes::(&dummy_config, &mut dummy_region, &[x], 1).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[0, 2]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn argmin_axes( +pub fn argmin_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2200,20 +2210,20 @@ pub fn argmin_axes( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = min_axes::(&dummy_config, &mut dummy_region, &[x], &[1]).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[2, 0]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn min_axes( +pub fn min_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2225,7 +2235,7 @@ pub fn min_axes( } /// Pairwise (elementwise) op layout -pub(crate) fn pairwise( +pub(crate) fn pairwise( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2390,20 +2400,20 @@ pub(crate) fn pairwise::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = mean_of_squares_axes::(&dummy_config, &mut dummy_region, &[x], &[1]).unwrap(); -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[78, 1]), /// &[2, 1], /// ).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn mean_of_squares_axes( +pub fn mean_of_squares_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2419,7 +2429,7 @@ pub fn mean_of_squares_axes( +pub(crate) fn expand( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2445,21 +2455,21 @@ pub(crate) fn expand( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 12, 6, 4, 5, 6]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap()); /// let result = greater::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn greater( +pub fn greater( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2495,22 +2505,22 @@ pub fn greater( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 12, 6, 4, 3, 2]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 4]), /// &[2, 3], /// ).unwrap()); /// let result = greater_equal::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn greater_equal( +pub fn greater_equal( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2547,22 +2557,22 @@ pub fn greater_equal( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 5, 4, 5, 1]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap()); /// let result = less::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` /// -pub fn less( +pub fn less( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2586,22 +2596,22 @@ pub fn less( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 5, 4, 5, 1]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap()); /// let result = less_equal::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` /// -pub fn less_equal( +pub fn less_equal( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2625,21 +2635,21 @@ pub fn less_equal( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 1, 0, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = and::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn and( +pub fn and( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2667,21 +2677,21 @@ pub fn and( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 1, 0, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = or::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn or( +pub fn or( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2713,21 +2723,21 @@ pub fn or( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 1, 0, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = equals::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn equals( +pub fn equals( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2737,7 +2747,7 @@ pub fn equals( } /// Equality boolean operation -pub(crate) fn equals_zero( +pub(crate) fn equals_zero( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2785,22 +2795,22 @@ pub(crate) fn equals_zero::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 1, 0, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = xor::(&dummy_config, &mut dummy_region, &[a,b]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 1, 0, 1, 0, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 1, 0, 1, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` /// -pub fn xor( +pub fn xor( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -2839,17 +2849,17 @@ pub fn xor( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1, 1, 0]), /// &[2, 3], /// ).unwrap()); /// let result = not::(&dummy_config, &mut dummy_region, &[x]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn not( +pub fn not( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2880,25 +2890,25 @@ pub fn not( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let mask = ValTensor::from_i128_tensor(Tensor::::new( +/// let mask = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 1, 0, 1, 0]), /// &[2, 3], /// ).unwrap()); -/// let a = ValTensor::from_i128_tensor(Tensor::::new( +/// let a = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[7, 8, 9, 10, 11, 12]), /// &[2, 3], /// ).unwrap()); /// let result = iff::(&dummy_config, &mut dummy_region, &[mask, a, b]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 8, 3, 10, 5, 12]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 8, 3, 10, 5, 12]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn iff( +pub fn iff( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 3], @@ -2936,17 +2946,17 @@ pub fn iff( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap()); /// let result = neg::(&dummy_config, &mut dummy_region, &[x]).unwrap(); -/// let expected = Tensor::::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn neg( +pub fn neg( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -2966,23 +2976,23 @@ pub fn neg( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 1, 3, 3], /// ).unwrap()); /// let pooled = sumpool::(&dummy_config, &mut dummy_region, &[x.clone()], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], false).unwrap(); -/// let expected: Tensor = Tensor::::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap(); +/// let expected: Tensor = Tensor::::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(pooled.get_int_evals().unwrap(), expected); /// /// // This time with normalization /// let pooled = sumpool::(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2, 2], true).unwrap(); -/// let expected: Tensor = Tensor::::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap(); +/// let expected: Tensor = Tensor::::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(pooled.get_int_evals().unwrap(), expected); /// ``` -pub fn sumpool( +pub fn sumpool( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor], @@ -3046,19 +3056,19 @@ pub fn sumpool( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 1, 3, 3], /// ).unwrap()); /// let pooled = max_pool::(&dummy_config, &mut dummy_region, &[x], &vec![(0, 0); 2], &vec![1;2], &vec![2;2]).unwrap(); -/// let expected: Tensor = Tensor::::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap(); +/// let expected: Tensor = Tensor::::new(Some(&[5, 4, 4, 6]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(pooled.get_int_evals().unwrap(), expected); /// /// ``` -pub fn max_pool( +pub fn max_pool( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3145,121 +3155,127 @@ pub fn max_pool( /// use ezkl::tensor::ValTensor; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let c = ValTensor::from_i128_tensor(Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap()); -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let c = ValTensor::from_i64_tensor(Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap()); +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); /// /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 1, 1, 5]), /// &[1, 1, 2, 2], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 1, 1, 5]), /// &[1, 1, 2, 2], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[17]), &[1, 1, 1, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[17]), &[1, 1, 1, 1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 1, 1, 5]), /// &[1, 1, 2, 2], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 1, 1, 5]), /// &[1, 1, 2, 2], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 2]), /// &[1, 1, 2, 1], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 2]), /// &[1, 1, 2, 1], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// -/// let c = ValTensor::from_i128_tensor(Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap()); -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let c = ValTensor::from_i64_tensor(Tensor::::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 2, 2, 3]).unwrap()); +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 4, 0, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); /// /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[3, 8, 0, 8, 4, 9, 8, 1, 8]), /// &[1, 1, 3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 0, 4, 6]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1]), /// &[1], /// ).unwrap()); /// let result = deconv::(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// ``` pub fn deconv< - F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync, + F: PrimeField + + TensorType + + PartialOrd + + std::hash::Hash + + std::marker::Send + + std::marker::Sync + + IntoI64, >( config: &BaseConfig, region: &mut RegionCtx, @@ -3368,63 +3384,69 @@ pub fn deconv< /// use ezkl::circuit::BaseConfig; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 1, 3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 1, 1, 1]), /// &[1, 1, 2, 2], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[0]), /// &[1], /// ).unwrap()); /// let result = conv::(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // Now test single channel -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 2, 3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 1, 1, 1, 5, 2, 1, 1]), /// &[2, 1, 2, 2], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1]), /// &[2], /// ).unwrap()); /// /// let result = conv::(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// /// // Now test multi channel -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 2, 3, 3], /// ).unwrap()); -/// let k = ValTensor::from_i128_tensor(Tensor::::new( +/// let k = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]), /// &[4, 2, 2, 2], /// ).unwrap()); -/// let b = ValTensor::from_i128_tensor(Tensor::::new( +/// let b = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[1, 1, 1, 1]), /// &[4], /// ).unwrap()); /// /// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap(); -/// let expected = Tensor::::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` /// pub fn conv< - F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync, + F: PrimeField + + TensorType + + PartialOrd + + std::hash::Hash + + std::marker::Send + + std::marker::Sync + + IntoI64, >( config: &BaseConfig, region: &mut RegionCtx, @@ -3596,7 +3618,7 @@ pub fn conv< } /// Power accumulated layout -pub(crate) fn pow( +pub(crate) fn pow( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3612,7 +3634,7 @@ pub(crate) fn pow( } /// Rescaled op accumulated layout -pub(crate) fn rescale( +pub(crate) fn rescale( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor], @@ -3634,7 +3656,7 @@ pub(crate) fn rescale } /// Dummy (no contraints) reshape layout -pub(crate) fn reshape( +pub(crate) fn reshape( values: &[ValTensor; 1], new_dims: &[usize], ) -> Result, Box> { @@ -3644,7 +3666,7 @@ pub(crate) fn reshape } /// Dummy (no contraints) move_axis layout -pub(crate) fn move_axis( +pub(crate) fn move_axis( values: &[ValTensor; 1], source: usize, destination: usize, @@ -3655,7 +3677,7 @@ pub(crate) fn move_axis( +pub(crate) fn resize( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3669,7 +3691,7 @@ pub(crate) fn resize( } /// Slice layout -pub(crate) fn slice( +pub(crate) fn slice( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3692,7 +3714,7 @@ pub(crate) fn slice( } /// Trilu layout -pub(crate) fn trilu( +pub(crate) fn trilu( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3716,7 +3738,7 @@ pub(crate) fn trilu( } /// Concat layout -pub(crate) fn concat( +pub(crate) fn concat( values: &[ValTensor], axis: &usize, ) -> Result, Box> { @@ -3728,7 +3750,7 @@ pub(crate) fn concat( } /// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon. -pub(crate) fn identity( +pub(crate) fn identity( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3743,7 +3765,9 @@ pub(crate) fn identity( +pub(crate) fn boolean_identity< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3779,7 +3803,7 @@ pub(crate) fn boolean_identity( +pub(crate) fn downsample( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3796,7 +3820,9 @@ pub(crate) fn downsample( +pub(crate) fn enforce_equality< + F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64, +>( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -3822,7 +3848,7 @@ pub(crate) fn enforce_equality( +pub(crate) fn range_check( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -3879,7 +3905,7 @@ pub(crate) fn range_check( +pub(crate) fn nonlinearity( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4006,7 +4032,7 @@ pub(crate) fn nonlinearity( +pub(crate) fn argmax( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4018,10 +4044,10 @@ pub(crate) fn argmax( .enumerate() // we value the first index in the case of a tie .max_by_key(|(idx, value)| (*value, -(*idx as i64))) - .map(|(idx, _)| idx as i128); + .map(|(idx, _)| idx as i64); let argmax_val: ValTensor = match argmax { None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), - Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i64_to_felt::(i))]), &[1])?.into(), }; let assigned_argmax: ValTensor = @@ -4042,7 +4068,7 @@ pub(crate) fn argmax( } /// Argmin -pub(crate) fn argmin( +pub(crate) fn argmin( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4054,10 +4080,10 @@ pub(crate) fn argmin( .enumerate() // we value the first index in the case of a tie .min_by_key(|(idx, value)| (*value, (*idx as i64))) - .map(|(idx, _)| idx as i128); + .map(|(idx, _)| idx as i64); let argmin_val: ValTensor = match argmin { None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), - Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i64_to_felt::(i))]), &[1])?.into(), }; let assigned_argmin: ValTensor = @@ -4078,7 +4104,7 @@ pub(crate) fn argmin( } /// max layout -pub(crate) fn max( +pub(crate) fn max( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4088,7 +4114,7 @@ pub(crate) fn max( } /// min layout -pub(crate) fn min( +pub(crate) fn min( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4096,7 +4122,7 @@ pub(crate) fn min( _sort_ascending(config, region, values)?.get_slice(&[0..1]) } -fn multi_dim_axes_op( +fn multi_dim_axes_op( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4199,7 +4225,7 @@ fn multi_dim_axes_op( } /// softmax layout -pub(crate) fn softmax_axes( +pub(crate) fn softmax_axes( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4220,7 +4246,7 @@ pub(crate) fn softmax_axes( +pub(crate) fn percent( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4263,18 +4289,18 @@ pub(crate) fn percent /// use ezkl::circuit::BaseConfig; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[2, 2, 3, 2, 2, 0]), /// &[2, 3], /// ).unwrap()); /// let result = softmax::(&dummy_config, &mut dummy_region, &[x], 128.0.into(), (128.0 * 128.0).into()).unwrap(); /// // doubles the scale of the input -/// let expected = Tensor::::new(Some(&[2734, 2734, 2756, 2734, 2734, 2691]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[2734, 2734, 2756, 2734, 2734, 2691]), &[2, 3]).unwrap(); /// assert_eq!(result.get_int_evals().unwrap(), expected); /// ``` -pub fn softmax( +pub fn softmax( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], @@ -4309,19 +4335,19 @@ pub fn softmax( /// use ezkl::circuit::BaseConfig; /// /// let dummy_config = BaseConfig::dummy(12, 2); -/// let mut dummy_region = RegionCtx::new_dummy(0,2,true); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,true,true); /// -/// let x = ValTensor::from_i128_tensor(Tensor::::new( +/// let x = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[100, 200, 300, 400, 500, 600]), /// &[2, 3], /// ).unwrap()); -/// let y = ValTensor::from_i128_tensor(Tensor::::new( +/// let y = ValTensor::from_i64_tensor(Tensor::::new( /// Some(&[101, 201, 302, 403, 503, 603]), /// &[2, 3], /// ).unwrap()); /// let result = range_check_percent::(&dummy_config, &mut dummy_region, &[x, y], 1024.0.into(), 1.0).unwrap(); /// ``` -pub fn range_check_percent( +pub fn range_check_percent( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -4346,9 +4372,9 @@ pub fn range_check_percent Range { let range = (max_len - 1) as f64 / 2_f64; - let range = range as i128; + let range = range as i64; (-range, range) } /// Matches a [Op] to an operation in the `tensor::ops` module. - pub(crate) fn f( + pub(crate) fn f( &self, x: &[Tensor], ) -> Result, TensorError> { - let x = x[0].clone().map(|x| felt_to_i128(x)); + let x = x[0].clone().map(|x| felt_to_i64(x)); let res = match &self { LookupOp::Abs => Ok(tensor::ops::abs(&x)?), LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())), @@ -228,13 +228,13 @@ impl LookupOp { } }?; - let output = res.map(|x| i128_to_felt(x)); + let output = res.map(|x| i64_to_felt(x)); Ok(ForwardResult { output }) } } -impl Op for LookupOp { +impl Op for LookupOp { /// Returns a reference to the Any trait. fn as_any(&self) -> &dyn Any { self diff --git a/src/circuit/ops/mod.rs b/src/circuit/ops/mod.rs index 9d4d0fbe4..903b66299 100644 --- a/src/circuit/ops/mod.rs +++ b/src/circuit/ops/mod.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::{ graph::quantize_tensor, - tensor::{self, Tensor, TensorType, ValTensor}, + tensor::{self, IntoI64, Tensor, TensorType, ValTensor}, }; use halo2curves::ff::PrimeField; @@ -27,12 +27,12 @@ pub mod region; /// A struct representing the result of a forward pass. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub struct ForwardResult { +pub struct ForwardResult { pub(crate) output: Tensor, } /// A trait representing operations that can be represented as constraints in a circuit. -pub trait Op: +pub trait Op: std::fmt::Debug + Send + Sync + Any { /// Returns a string representation of the operation. @@ -71,7 +71,7 @@ pub trait Op: fn as_any(&self) -> &dyn Any; } -impl Clone for Box> { +impl Clone for Box> { fn clone(&self) -> Self { self.clone_dyn() } @@ -122,8 +122,8 @@ impl InputType { *input = T::from_f64(f64_input).unwrap(); } InputType::Int | InputType::TDim => { - let int_input = input.clone().to_i128().unwrap(); - *input = T::from_i128(int_input).unwrap(); + let int_input = input.clone().to_i64().unwrap(); + *input = T::from_i64(int_input).unwrap(); } } } @@ -138,7 +138,7 @@ pub struct Input { pub datum_type: InputType, } -impl Op for Input { +impl Op for Input { fn out_scale(&self, _: Vec) -> Result> { Ok(self.scale) } @@ -193,7 +193,7 @@ impl Op for Input #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] pub struct Unknown; -impl Op for Unknown { +impl Op for Unknown { fn out_scale(&self, _: Vec) -> Result> { Ok(0) } @@ -220,7 +220,7 @@ impl Op for Unknow /// #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Constant { +pub struct Constant { /// pub quantized_values: Tensor, /// @@ -230,7 +230,7 @@ pub struct Constant { pub pre_assigned_val: Option>, } -impl Constant { +impl Constant { /// pub fn new(quantized_values: Tensor, raw_values: Tensor) -> Self { Self { @@ -263,7 +263,8 @@ impl< + PartialOrd + std::hash::Hash + Serialize - + for<'de> Deserialize<'de>, + + for<'de> Deserialize<'de> + + IntoI64, > Op for Constant { fn as_any(&self) -> &dyn Any { diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs index ef91734d7..e8ed1a80b 100644 --- a/src/circuit/ops/poly.rs +++ b/src/circuit/ops/poly.rs @@ -97,7 +97,8 @@ impl< + PartialOrd + std::hash::Hash + Serialize - + for<'de> Deserialize<'de>, + + for<'de> Deserialize<'de> + + IntoI64, > Op for PolyOp { /// Returns a reference to the Any trait. diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index 65f58e72e..6b7567a58 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -9,7 +9,7 @@ use halo2_proofs::{ plonk::{Error, Selector}, }; use halo2curves::ff::PrimeField; -use portable_atomic::AtomicI128 as AtomicInt; +use portable_atomic::AtomicI64 as AtomicInt; use std::{ cell::RefCell, collections::{HashMap, HashSet}, @@ -133,10 +133,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha shuffle_index: ShuffleIndex, used_lookups: HashSet, used_range_checks: HashSet, - max_lookup_inputs: i128, - min_lookup_inputs: i128, - max_range_size: i128, + max_lookup_inputs: i64, + min_lookup_inputs: i64, + max_range_size: i64, witness_gen: bool, + check_lookup_range: bool, assigned_constants: ConstantsMap, } @@ -191,6 +192,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a self.witness_gen } + /// + pub fn check_lookup_range(&self) -> bool { + self.check_lookup_range + } + /// Create a new region context pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { let region = Some(RefCell::new(region)); @@ -209,6 +215,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a min_lookup_inputs: 0, max_range_size: 0, witness_gen: true, + check_lookup_range: true, assigned_constants: HashMap::new(), } } @@ -246,12 +253,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a min_lookup_inputs: 0, max_range_size: 0, witness_gen: false, + check_lookup_range: false, assigned_constants: HashMap::new(), } } /// Create a new region context - pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> { + pub fn new_dummy( + row: usize, + num_inner_cols: usize, + witness_gen: bool, + check_lookup_range: bool, + ) -> RegionCtx<'a, F> { let region = None; let linear_coord = row * num_inner_cols; @@ -268,6 +281,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a min_lookup_inputs: 0, max_range_size: 0, witness_gen, + check_lookup_range, assigned_constants: HashMap::new(), } } @@ -278,6 +292,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a linear_coord: usize, num_inner_cols: usize, witness_gen: bool, + check_lookup_range: bool, ) -> RegionCtx<'a, F> { let region = None; RegionCtx { @@ -293,6 +308,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a min_lookup_inputs: 0, max_range_size: 0, witness_gen, + check_lookup_range, assigned_constants: HashMap::new(), } } @@ -364,6 +380,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a starting_linear_coord, self.num_inner_cols, self.witness_gen, + self.check_lookup_range, ); let res = inner_loop_function(idx, &mut local_reg); // we update the offset and constants @@ -546,17 +563,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a } /// max lookup inputs - pub fn max_lookup_inputs(&self) -> i128 { + pub fn max_lookup_inputs(&self) -> i64 { self.max_lookup_inputs } /// min lookup inputs - pub fn min_lookup_inputs(&self) -> i128 { + pub fn min_lookup_inputs(&self) -> i64 { self.min_lookup_inputs } /// max range check - pub fn max_range_size(&self) -> i128 { + pub fn max_range_size(&self) -> i64 { self.max_range_size } diff --git a/src/circuit/table.rs b/src/circuit/table.rs index 82388665c..c965dd351 100644 --- a/src/circuit/table.rs +++ b/src/circuit/table.rs @@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}; use crate::{ circuit::CircuitError, - fieldutils::i128_to_felt, - tensor::{Tensor, TensorType}, + fieldutils::i64_to_felt, + tensor::{IntoI64, Tensor, TensorType}, }; use crate::circuit::lookup::LookupOp; /// The range of the lookup table. -pub type Range = (i128, i128); +pub type Range = (i64, i64); /// The safety factor for the range of the lookup table. -pub const RANGE_MULTIPLIER: i128 = 2; +pub const RANGE_MULTIPLIER: i64 = 2; /// The safety factor offset for the number of rows in the lookup table. pub const RESERVED_BLINDING_ROWS_PAD: usize = 3; @@ -96,21 +96,21 @@ pub struct Table { _marker: PhantomData, } -impl Table { +impl Table { /// get column index given input pub fn get_col_index(&self, input: F) -> F { // range is split up into chunks of size col_size, find the chunk that input is in let chunk = - (crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128); + (crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64); - i128_to_felt(chunk) + i64_to_felt(chunk) } /// get first_element of column pub fn get_first_element(&self, chunk: usize) -> (F, F) { - let chunk = chunk as i128; + let chunk = chunk as i64; // we index from 1 to prevent soundness issues - let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0); + let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0); let op_f = self .nonlinearity .f(&[Tensor::from(vec![first_element].into_iter())]) @@ -130,12 +130,12 @@ impl Table { } /// -pub fn num_cols_required(range_len: i128, col_size: usize) -> usize { +pub fn num_cols_required(range_len: i64, col_size: usize) -> usize { // number of cols needed to store the range - (range_len / (col_size as i128)) as usize + 1 + (range_len / (col_size as i64)) as usize + 1 } -impl Table { +impl Table { /// Configures the table. pub fn configure( cs: &mut ConstraintSystem, @@ -202,7 +202,7 @@ impl Table { let smallest = self.range.0; let largest = self.range.1; - let inputs: Tensor = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x)); + let inputs: Tensor = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x)); let evals = self.nonlinearity.f(&[inputs.clone()])?; let chunked_inputs = inputs.chunks(self.col_size); @@ -272,12 +272,12 @@ pub struct RangeCheck { _marker: PhantomData, } -impl RangeCheck { +impl RangeCheck { /// get first_element of column pub fn get_first_element(&self, chunk: usize) -> F { - let chunk = chunk as i128; + let chunk = chunk as i64; // we index from 1 to prevent soundness issues - i128_to_felt(chunk * (self.col_size as i128) + self.range.0) + i64_to_felt(chunk * (self.col_size as i64) + self.range.0) } /// @@ -294,13 +294,13 @@ impl RangeCheck { pub fn get_col_index(&self, input: F) -> F { // range is split up into chunks of size col_size, find the chunk that input is in let chunk = - (crate::fieldutils::felt_to_i128(input) - self.range.0).abs() / (self.col_size as i128); + (crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64); - i128_to_felt(chunk) + i64_to_felt(chunk) } } -impl RangeCheck { +impl RangeCheck { /// Configures the table. pub fn configure(cs: &mut ConstraintSystem, range: Range, logrows: usize) -> RangeCheck { log::debug!("range check range: {:?}", range); @@ -350,7 +350,7 @@ impl RangeCheck { let smallest = self.range.0; let largest = self.range.1; - let inputs: Tensor = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x)); + let inputs: Tensor = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x)); let chunked_inputs = inputs.chunks(self.col_size); self.is_assigned = true; diff --git a/src/commands.rs b/src/commands.rs index 28ee046ee..712c4121c 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -345,7 +345,7 @@ pub enum Commands { target: CalibrationTarget, /// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower #[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)] - lookup_safety_margin: i128, + lookup_safety_margin: i64, /// Optional scales to specifically try for calibration. Example, --scales 0,4 #[arg(long, value_delimiter = ',', allow_hyphen_values = true)] scales: Option>, diff --git a/src/eth.rs b/src/eth.rs index 3be1476ba..b12a36784 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -424,7 +424,7 @@ pub async fn setup_test_contract( let input = input.to_float() as f32; let decimal_places = count_decimal_places(input) as u8; let scaled_by_decimals = input * f32::powf(10., decimal_places.into()); - scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i128)); + scaled_by_decimals_data.push(I256::from(scaled_by_decimals as i64)); decimals.push(decimal_places); } else if input.is_field() { let input = input.to_field(0); diff --git a/src/execute.rs b/src/execute.rs index ce44ea531..2ef86af62 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -682,6 +682,7 @@ pub(crate) fn gen_witness( vk.as_ref(), Some(&srs), true, + true, )? } Commitments::IPA => { @@ -696,15 +697,22 @@ pub(crate) fn gen_witness( vk.as_ref(), Some(&srs), true, + true, )? } } } else { warn!("SRS for poly commit does not exist (will be ignored)"); - circuit.forward::>(&mut input, vk.as_ref(), None, true)? + circuit.forward::>( + &mut input, + vk.as_ref(), + None, + true, + true, + )? } } else { - circuit.forward::>(&mut input, vk.as_ref(), None, true)? + circuit.forward::>(&mut input, vk.as_ref(), None, true, true)? }; // print each variable tuple (symbol, value) as symbol=value @@ -887,7 +895,7 @@ pub(crate) fn calibrate( data: PathBuf, settings_path: PathBuf, target: CalibrationTarget, - lookup_safety_margin: i128, + lookup_safety_margin: i64, scales: Option>, scale_rebase_multiplier: Vec, only_range_check_rebase: bool, @@ -1003,6 +1011,7 @@ pub(crate) fn calibrate( param_scale, scale_rebase_multiplier, div_rebasing, + lookup_range: (i64::MIN, i64::MAX), ..settings.run_args.clone() }; @@ -1038,7 +1047,13 @@ pub(crate) fn calibrate( .map_err(|e| format!("failed to load circuit inputs: {}", e))?; let forward_res = circuit - .forward::>(&mut data.clone(), None, None, true) + .forward::>( + &mut data.clone(), + None, + None, + true, + false, + ) .map_err(|e| format!("failed to forward: {}", e))?; // push result to the hashmap @@ -1053,7 +1068,7 @@ pub(crate) fn calibrate( match forward_res { Ok(_) => (), - // typically errors will be due to the circuit overflowing the i128 limit + // typically errors will be due to the circuit overflowing the i64 limit Err(e) => { error!("forward pass failed: {:?}", e); pb.inc(1); diff --git a/src/fieldutils.rs b/src/fieldutils.rs index d3a606d83..0bfc43b1f 100644 --- a/src/fieldutils.rs +++ b/src/fieldutils.rs @@ -11,8 +11,8 @@ pub fn i32_to_felt(x: i32) -> F { } } -/// Converts an i128 to a PrimeField element. -pub fn i128_to_felt(x: i128) -> F { +/// Converts an i64 to a PrimeField element. +pub fn i64_to_felt(x: i64) -> F { if x >= 0 { F::from_u128(x as u128) } else { @@ -37,7 +37,7 @@ pub fn felt_to_i32(x: F) -> i32 { /// Converts a PrimeField element to an f64. pub fn felt_to_f64(x: F) -> f64 { - if x > F::from_u128(i128::MAX as u128) { + if x > F::from_u128(i64::MAX as u128) { let rep = (-x).to_repr(); let negtmp: &[u8] = rep.as_ref(); let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap()); @@ -50,18 +50,18 @@ pub fn felt_to_f64(x: F) -> f64 { } } -/// Converts a PrimeField element to an i128. -pub fn felt_to_i128(x: F) -> i128 { - if x > F::from_u128(i128::MAX as u128) { +/// Converts a PrimeField element to an i64. +pub fn felt_to_i64(x: F) -> i64 { + if x > F::from_u128(i64::MAX as u128) { let rep = (-x).to_repr(); let negtmp: &[u8] = rep.as_ref(); let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap()); - -(lower_128 as i128) + -(lower_128 as i64) } else { let rep = (x).to_repr(); let tmp: &[u8] = rep.as_ref(); let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap()); - lower_128 as i128 + lower_128 as i64 } } @@ -79,10 +79,10 @@ mod test { let res: F = i32_to_felt(2_i32.pow(17)); assert_eq!(res, F::from(131072)); - let res: F = i128_to_felt(-15i128); + let res: F = i64_to_felt(-15i64); assert_eq!(res, -F::from(15)); - let res: F = i128_to_felt(2_i128.pow(17)); + let res: F = i64_to_felt(2_i64.pow(17)); assert_eq!(res, F::from(131072)); } @@ -96,10 +96,10 @@ mod test { } #[test] - fn felttoi128() { - for x in -(2i128.pow(20))..(2i128.pow(20)) { - let fieldx: F = i128_to_felt::(x); - let xf: i128 = felt_to_i128::(fieldx); + fn felttoi64() { + for x in -(2i64.pow(20))..(2i64.pow(20)) { + let fieldx: F = i64_to_felt::(x); + let xf: i64 = felt_to_i64::(fieldx); assert_eq!(x, xf); } } diff --git a/src/graph/input.rs b/src/graph/input.rs index e31211071..ff4ade430 100644 --- a/src/graph/input.rs +++ b/src/graph/input.rs @@ -1,7 +1,7 @@ use super::quantize_float; use super::GraphError; use crate::circuit::InputType; -use crate::fieldutils::i128_to_felt; +use crate::fieldutils::i64_to_felt; #[cfg(not(target_arch = "wasm32"))] use crate::tensor::Tensor; use crate::EZKL_BUF_CAPACITY; @@ -128,7 +128,7 @@ impl FileSourceInner { /// Convert to a field element pub fn to_field(&self, scale: crate::Scale) -> Fp { match self { - FileSourceInner::Float(f) => i128_to_felt(quantize_float(f, 0.0, scale).unwrap()), + FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()), FileSourceInner::Bool(f) => { if *f { Fp::one() @@ -150,7 +150,7 @@ impl FileSourceInner { 0.0 } } - FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64, + FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64, } } } diff --git a/src/graph/mod.rs b/src/graph/mod.rs index c744208bf..495dcc5ea 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -62,13 +62,13 @@ pub use vars::*; use crate::pfsys::field_to_string; /// The safety factor for the range of the lookup table. -pub const RANGE_MULTIPLIER: i128 = 2; +pub const RANGE_MULTIPLIER: i64 = 2; /// The maximum number of columns in a lookup table. pub const MAX_NUM_LOOKUP_COLS: usize = 12; /// Max representation of a lookup table input -pub const MAX_LOOKUP_ABS: i128 = (MAX_NUM_LOOKUP_COLS as i128) * 2_i128.pow(MAX_PUBLIC_SRS); +pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS); #[cfg(not(target_arch = "wasm32"))] lazy_static! { @@ -175,11 +175,11 @@ pub struct GraphWitness { /// Any hashes of outputs generated during the forward pass pub processed_outputs: Option, /// max lookup input - pub max_lookup_inputs: i128, + pub max_lookup_inputs: i64, /// max lookup input - pub min_lookup_inputs: i128, + pub min_lookup_inputs: i64, /// max range check size - pub max_range_size: i128, + pub max_range_size: i64, } impl GraphWitness { @@ -1098,14 +1098,14 @@ impl GraphCircuit { Ok(data) } - fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range { + fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range { ( lookup_safety_margin * min_max_lookup.0, lookup_safety_margin * min_max_lookup.1, ) } - fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize { + fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize { let max_col_size = Table::::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS); num_cols_required(range_len, max_col_size) } @@ -1113,7 +1113,7 @@ impl GraphCircuit { fn table_size_logrows( &self, safe_lookup_range: Range, - max_range_size: i128, + max_range_size: i64, ) -> Result> { // pick the range with the largest absolute size safe_lookup_range or max_range_size let safe_range = std::cmp::max( @@ -1132,9 +1132,9 @@ impl GraphCircuit { pub fn calc_min_logrows( &mut self, min_max_lookup: Range, - max_range_size: i128, + max_range_size: i64, max_logrows: Option, - lookup_safety_margin: i128, + lookup_safety_margin: i64, ) -> Result<(), Box> { // load the max logrows let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS); @@ -1228,7 +1228,7 @@ impl GraphCircuit { &self, k: u32, safe_lookup_range: Range, - max_range_size: i128, + max_range_size: i64, ) -> bool { // if num cols is too large then the extended k is too large if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS @@ -1287,6 +1287,7 @@ impl GraphCircuit { vk: Option<&VerifyingKey>, srs: Option<&Scheme::ParamsProver>, witness_gen: bool, + check_lookup: bool, ) -> Result> { let original_inputs = inputs.to_vec(); @@ -1335,7 +1336,7 @@ impl GraphCircuit { let mut model_results = self.model() - .forward(inputs, &self.settings().run_args, witness_gen)?; + .forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?; if visibility.output.requires_processing() { let module_outlets = visibility.output.overwrites_inputs(); diff --git a/src/graph/model.rs b/src/graph/model.rs index 720d86bde..c30870047 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -65,11 +65,11 @@ pub struct ForwardResult { /// The outputs of the forward pass. pub outputs: Vec>, /// The maximum value of any input to a lookup operation. - pub max_lookup_inputs: i128, + pub max_lookup_inputs: i64, /// The minimum value of any input to a lookup operation. - pub min_lookup_inputs: i128, + pub min_lookup_inputs: i64, /// The max range check size - pub max_range_size: i128, + pub max_range_size: i64, } impl From for ForwardResult { @@ -117,11 +117,11 @@ pub struct DummyPassRes { /// range checks pub range_checks: HashSet, /// max lookup inputs - pub max_lookup_inputs: i128, + pub max_lookup_inputs: i64, /// min lookup inputs - pub min_lookup_inputs: i128, + pub min_lookup_inputs: i64, /// min range check - pub max_range_size: i128, + pub max_range_size: i64, /// outputs pub outputs: Vec>, } @@ -538,7 +538,7 @@ impl Model { }) .collect::, Box>>()?; - let res = self.dummy_layout(run_args, &inputs, false)?; + let res = self.dummy_layout(run_args, &inputs, false, false)?; // if we're using percentage tolerance, we need to add the necessary range check ops for it. @@ -582,12 +582,13 @@ impl Model { model_inputs: &[Tensor], run_args: &RunArgs, witness_gen: bool, + check_lookup: bool, ) -> Result> { let valtensor_inputs: Vec> = model_inputs .iter() .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into()) .collect(); - let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?; + let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?; Ok(res.into()) } @@ -1392,6 +1393,7 @@ impl Model { run_args: &RunArgs, inputs: &[ValTensor], witness_gen: bool, + check_lookup: bool, ) -> Result> { debug!("calculating num of constraints using dummy model layout..."); @@ -1410,7 +1412,8 @@ impl Model { vars: ModelVars::new_dummy(), }; - let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen); + let mut region = + RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup); let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?; diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 756f60126..2415713f3 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -52,16 +52,16 @@ use tract_onnx::tract_hir::{ /// * `dims` - the dimensionality of the resulting [Tensor]. /// * `shift` - offset used in the fixed point representation. /// * `scale` - `2^scale` used in the fixed point representation. -pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result { +pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result { let mult = scale_to_multiplier(scale); - let max_value = ((i128::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation + let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation if *elem > max_value { return Err(TensorError::SigBitTruncationError); } // we parallelize the quantization process as it seems to be quite slow at times - let scaled = (mult * *elem + shift).round() as i128; + let scaled = (mult * *elem + shift).round() as i64; Ok(scaled) } @@ -72,7 +72,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result f64 { - let int_rep = crate::fieldutils::felt_to_i128(felt); + let int_rep = crate::fieldutils::felt_to_i64(felt); let multiplier = scale_to_multiplier(scale); int_rep as f64 / multiplier - shift } @@ -1475,7 +1475,7 @@ pub fn quantize_tensor( visibility: &Visibility, ) -> Result, Box> { let mut value: Tensor = const_value.par_enum_map(|_, x| { - Ok::<_, TensorError>(crate::fieldutils::i128_to_felt::(quantize_float( + Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::(quantize_float( &(x).into(), 0.0, scale, diff --git a/src/lib.rs b/src/lib.rs index 62838b7f4..77607bee5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ )] // we allow this for our dynamic range based indexing scheme #![allow(clippy::single_range_in_vec_init)] +#![feature(stmt_expr_attributes)] //! A library for turning computational graphs, such as neural networks, into ZK-circuits. //! @@ -189,7 +190,7 @@ pub struct RunArgs { #[arg(long, default_value = "1")] pub scale_rebase_multiplier: u32, /// The min and max elements in the lookup table input column - #[arg(short = 'B', long, value_parser = parse_key_val::, default_value = "-32768->32768")] + #[arg(short = 'B', long, value_parser = parse_key_val::, default_value = "-32768->32768")] pub lookup_range: Range, /// The log_2 number of rows #[arg(short = 'K', long, default_value = "17")] diff --git a/src/python.rs b/src/python.rs index 8586cda25..14b6ee63c 100644 --- a/src/python.rs +++ b/src/python.rs @@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{ use crate::circuit::modules::Module; use crate::circuit::{CheckMode, Tolerance}; use crate::commands::*; -use crate::fieldutils::{felt_to_i128, i128_to_felt}; +use crate::fieldutils::{felt_to_i64, i64_to_felt}; use crate::graph::modules::POSEIDON_LEN_GRAPH; use crate::graph::TestDataSource; use crate::graph::{ @@ -332,9 +332,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult { #[pyfunction(signature = ( felt, ))] -fn felt_to_int(felt: PyFelt) -> PyResult { +fn felt_to_int(felt: PyFelt) -> PyResult { let felt = crate::pfsys::string_to_field::(&felt); - let int_rep = felt_to_i128(felt); + let int_rep = felt_to_i64(felt); Ok(int_rep) } @@ -358,7 +358,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult { ))] fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult { let felt = crate::pfsys::string_to_field::(&felt); - let int_rep = felt_to_i128(felt); + let int_rep = felt_to_i64(felt); let multiplier = scale_to_multiplier(scale); let float_rep = int_rep as f64 / multiplier; Ok(float_rep) @@ -386,7 +386,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult { fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult { let int_rep = quantize_float(&input, 0.0, scale) .map_err(|_| PyIOError::new_err("Failed to quantize input"))?; - let felt = i128_to_felt(int_rep); + let felt = i64_to_felt(int_rep); Ok(crate::pfsys::field_to_string::(&felt)) } @@ -889,7 +889,7 @@ fn calibrate_settings( model: PathBuf, settings: PathBuf, target: CalibrationTarget, - lookup_safety_margin: i128, + lookup_safety_margin: i64, scales: Option>, scale_rebase_multiplier: Vec, max_logrows: Option, diff --git a/src/tensor/metal/tensor_ops.air b/src/tensor/metal/tensor_ops.air new file mode 100644 index 000000000..a985df504 Binary files /dev/null and b/src/tensor/metal/tensor_ops.air differ diff --git a/src/tensor/metal/tensor_ops.metal b/src/tensor/metal/tensor_ops.metal new file mode 100644 index 000000000..d8ced1e95 --- /dev/null +++ b/src/tensor/metal/tensor_ops.metal @@ -0,0 +1,31 @@ +[[kernel]] +void add( + constant long *inA [[buffer(0)]], + constant long *inB [[buffer(1)]], + device long *result [[buffer(2)]], + uint index [[thread_position_in_grid]]) +{ + result[index] = inA[index] + inB[index]; +} + + +[[kernel]] +void sub( + constant long *inA [[buffer(0)]], + constant long *inB [[buffer(1)]], + device long *result [[buffer(2)]], + uint index [[thread_position_in_grid]]) +{ + result[index] = inA[index] - inB[index]; +} + + +[[kernel]] +void mul( + constant long *inA [[buffer(0)]], + constant long *inB [[buffer(1)]], + device long *result [[buffer(2)]], + uint index [[thread_position_in_grid]]) +{ + result[index] = inA[index] * inB[index]; +} \ No newline at end of file diff --git a/src/tensor/metal/tensor_ops.metallib b/src/tensor/metal/tensor_ops.metallib new file mode 100644 index 000000000..138584801 Binary files /dev/null and b/src/tensor/metal/tensor_ops.metallib differ diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 6004d3658..3b1a77b09 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -5,7 +5,7 @@ pub mod val; /// A wrapper around a tensor of Halo2 Value types. pub mod var; -use halo2curves::ff::PrimeField; +use halo2curves::{bn256::Fr, ff::PrimeField}; use maybe_rayon::{ prelude::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, @@ -17,9 +17,12 @@ use serde::{Deserialize, Serialize}; pub use val::*; pub use var::*; +#[cfg(feature = "metal")] +use instant::Instant; + use crate::{ circuit::utils, - fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt}, + fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt}, graph::Visibility, }; @@ -30,12 +33,18 @@ use halo2_proofs::{ poly::Rotation, }; use itertools::Itertools; +#[cfg(feature = "metal")] +use metal::{Device, MTLResourceOptions, MTLSize}; use std::error::Error; use std::fmt::Debug; use std::iter::Iterator; use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub}; use std::{cmp::max, ops::Rem}; use thiserror::Error; + +#[cfg(feature = "metal")] +use std::collections::HashMap; + /// A wrapper for tensor related errors. #[derive(Debug, Error)] pub enum TensorError { @@ -65,6 +74,28 @@ pub enum TensorError { Overflow(String), } +#[cfg(feature = "metal")] +const LIB_DATA: &[u8] = include_bytes!("metal/tensor_ops.metallib"); + +#[cfg(feature = "metal")] +lazy_static::lazy_static! { + static ref DEVICE: Device = Device::system_default().expect("no device found"); + + static ref LIB: metal::Library = DEVICE.new_library_with_data(LIB_DATA).unwrap(); + + static ref QUEUE: metal::CommandQueue = DEVICE.new_command_queue(); + + static ref PIPELINES: HashMap = { + let mut map = HashMap::new(); + for name in ["add", "sub", "mul"] { + let function = LIB.get_function(name, None).unwrap(); + let pipeline = DEVICE.new_compute_pipeline_state_with_function(&function).unwrap(); + map.insert(name.to_string(), pipeline); + } + map + }; +} + /// The (inner) type of tensor elements. pub trait TensorType: Clone + Debug + 'static { /// Returns the zero value. @@ -145,7 +176,7 @@ impl TensorType for f64 { } tensor_type!(bool, Bool, false, true); -tensor_type!(i128, Int128, 0, 1); +tensor_type!(i64, Int64, 0, 1); tensor_type!(i32, Int32, 0, 1); tensor_type!(usize, USize, 0, 1); tensor_type!((), Empty, (), ()); @@ -311,6 +342,94 @@ impl DerefMut for Tensor { self.inner.deref_mut() } } +/// Convert to i64 trait +pub trait IntoI64 { + /// Convert to i64 + fn into_i64(self) -> i64; + + /// From i64 + fn from_i64(i: i64) -> Self; +} + +impl IntoI64 for i64 { + fn into_i64(self) -> i64 { + self + } + fn from_i64(i: i64) -> i64 { + i + } +} + +impl IntoI64 for i32 { + fn into_i64(self) -> i64 { + self as i64 + } + fn from_i64(i: i64) -> Self { + i as i32 + } +} + +impl IntoI64 for usize { + fn into_i64(self) -> i64 { + self as i64 + } + fn from_i64(i: i64) -> Self { + i as usize + } +} + +impl IntoI64 for f32 { + fn into_i64(self) -> i64 { + self as i64 + } + fn from_i64(i: i64) -> Self { + i as f32 + } +} + +impl IntoI64 for f64 { + fn into_i64(self) -> i64 { + self as i64 + } + fn from_i64(i: i64) -> Self { + i as f64 + } +} + +impl IntoI64 for () { + fn into_i64(self) -> i64 { + 0 + } + fn from_i64(_: i64) -> Self { + () + } +} + +impl IntoI64 for Fr { + fn into_i64(self) -> i64 { + felt_to_i64(self) + } + fn from_i64(i: i64) -> Self { + i64_to_felt::(i) + } +} + +impl IntoI64 for Value { + fn into_i64(self) -> i64 { + let mut res = vec![]; + self.map(|x| res.push(x.into_i64())); + + if res.len() == 0 { + 0 + } else { + res[0] + } + } + + fn from_i64(i: i64) -> Self { + Value::known(F::from_i64(i)) + } +} impl PartialEq for Tensor { fn eq(&self, other: &Tensor) -> bool { @@ -427,10 +546,10 @@ impl From> for Tensor> } } -impl From> for Tensor> { - fn from(t: Tensor) -> Tensor> { +impl From> for Tensor> { + fn from(t: Tensor) -> Tensor> { let mut ta: Tensor> = - Tensor::from((0..t.len()).map(|i| Value::known(i128_to_felt::(t[i])))); + Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::(t[i])))); // safe to unwrap as we know the dims are correct ta.reshape(t.dims()).unwrap(); ta @@ -1217,6 +1336,97 @@ impl Tensor { } } +#[cfg(feature = "metal")] +#[allow(unsafe_code)] +/// Perform a tensor operation on the GPU using Metal. +pub fn metal_tensor_op( + v: &Tensor, + w: &Tensor, + op: &str, +) -> Tensor { + assert_eq!(v.dims(), w.dims()); + + log::trace!("------------------------------------------------"); + + let start = Instant::now(); + let v = v + .par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64())) + .unwrap(); + let w = w + .par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64())) + .unwrap(); + log::trace!("Time to map tensors: {:?}", start.elapsed()); + + objc::rc::autoreleasepool(|| { + // create function pipeline. + // this compiles the function, so a pipline can't be created in performance sensitive code. + + let pipeline = &PIPELINES[op]; + + let length = v.len() as u64; + let size = length * core::mem::size_of::() as u64; + assert_eq!(v.len(), w.len()); + + let start = Instant::now(); + + let buffer_a = DEVICE.new_buffer_with_data( + unsafe { std::mem::transmute(v.as_ptr()) }, + size, + MTLResourceOptions::StorageModeShared, + ); + let buffer_b = DEVICE.new_buffer_with_data( + unsafe { std::mem::transmute(w.as_ptr()) }, + size, + MTLResourceOptions::StorageModeShared, + ); + let buffer_result = DEVICE.new_buffer( + size, // the operation will return an array with the same size. + MTLResourceOptions::StorageModeShared, + ); + + log::trace!("Time to load buffers: {:?}", start.elapsed()); + + // for sending commands, a command buffer is needed. + let start = Instant::now(); + let command_buffer = QUEUE.new_command_buffer(); + log::trace!("Time to load command buffer: {:?}", start.elapsed()); + // to write commands into a buffer an encoder is needed, in our case a compute encoder. + let start = Instant::now(); + let compute_encoder = command_buffer.new_compute_command_encoder(); + compute_encoder.set_compute_pipeline_state(&pipeline); + compute_encoder.set_buffers( + 0, + &[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)], + &[0; 3], + ); + log::trace!("Time to load compute encoder: {:?}", start.elapsed()); + + // specify thread count and organization + let start = Instant::now(); + let grid_size = MTLSize::new(length, 1, 1); + let threadgroup_size = MTLSize::new(length, 1, 1); + compute_encoder.dispatch_threads(grid_size, threadgroup_size); + log::trace!("Time to dispatch threads: {:?}", start.elapsed()); + + // end encoding and execute commands + let start = Instant::now(); + compute_encoder.end_encoding(); + command_buffer.commit(); + + command_buffer.wait_until_completed(); + log::trace!("Time to commit: {:?}", start.elapsed()); + + let start = Instant::now(); + let ptr = buffer_result.contents() as *const i64; + let len = buffer_result.length() as usize / std::mem::size_of::(); + let slice = unsafe { core::slice::from_raw_parts(ptr, len) }; + let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap(); + log::trace!("Time to get result: {:?}", start.elapsed()); + + res.map(|x| T::from_i64(x)) + }) +} + impl Tensor> { /// Flattens a tensor of tensors /// ``` @@ -1238,7 +1448,9 @@ impl Tensor> { } } -impl + std::marker::Send + std::marker::Sync> Add for Tensor { +impl + std::marker::Send + std::marker::Sync + IntoI64> Add + for Tensor +{ type Output = Result, TensorError>; /// Adds tensors. /// # Arguments @@ -1288,14 +1500,24 @@ impl + std::marker::Send + std::marker::Sync> Ad /// ``` fn add(self, rhs: Self) -> Self::Output { let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); - let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let lhs = self.expand(&broadcasted_shape).unwrap(); let rhs = rhs.expand(&broadcasted_shape).unwrap(); - lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { - *o = o.clone() + r; - }); + #[cfg(feature = "metal")] + let res = metal_tensor_op(&lhs, &rhs, "add"); + + #[cfg(not(feature = "metal"))] + let res = { + let mut res: Tensor = lhs + .par_iter() + .zip(rhs) + .map(|(o, r)| o.clone() + r) + .collect(); + res.reshape(&broadcasted_shape).unwrap(); + res + }; - Ok(lhs) + Ok(res) } } @@ -1318,6 +1540,7 @@ impl + std::marker::Send + std::marker::Sync> Ne /// ``` fn neg(self) -> Self { let mut output = self; + output.par_iter_mut().for_each(|x| { *x = x.clone().neg(); }); @@ -1325,7 +1548,9 @@ impl + std::marker::Send + std::marker::Sync> Ne } } -impl + std::marker::Send + std::marker::Sync> Sub for Tensor { +impl + std::marker::Send + std::marker::Sync + IntoI64> Sub + for Tensor +{ type Output = Result, TensorError>; /// Subtracts tensors. /// # Arguments @@ -1376,18 +1601,30 @@ impl + std::marker::Send + std::marker::Sync> Su /// ``` fn sub(self, rhs: Self) -> Self::Output { let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); - let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let lhs = self.expand(&broadcasted_shape).unwrap(); let rhs = rhs.expand(&broadcasted_shape).unwrap(); - lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { - *o = o.clone() - r; - }); + #[cfg(feature = "metal")] + let res = metal_tensor_op(&lhs, &rhs, "sub"); + + #[cfg(not(feature = "metal"))] + let res = { + let mut res: Tensor = lhs + .par_iter() + .zip(rhs) + .map(|(o, r)| o.clone() - r) + .collect(); + res.reshape(&broadcasted_shape).unwrap(); + res + }; - Ok(lhs) + Ok(res) } } -impl + std::marker::Send + std::marker::Sync> Mul for Tensor { +impl + std::marker::Send + std::marker::Sync + IntoI64> Mul + for Tensor +{ type Output = Result, TensorError>; /// Elementwise multiplies tensors. /// # Arguments @@ -1436,18 +1673,28 @@ impl + std::marker::Send + std::marker::Sync> Mu /// ``` fn mul(self, rhs: Self) -> Self::Output { let broadcasted_shape = get_broadcasted_shape(self.dims(), rhs.dims()).unwrap(); - let mut lhs = self.expand(&broadcasted_shape).unwrap(); + let lhs = self.expand(&broadcasted_shape).unwrap(); let rhs = rhs.expand(&broadcasted_shape).unwrap(); - lhs.par_iter_mut().zip(rhs).for_each(|(o, r)| { - *o = o.clone() * r; - }); + #[cfg(feature = "metal")] + let res = metal_tensor_op(&lhs, &rhs, "mul"); + + #[cfg(not(feature = "metal"))] + let res = { + let mut res: Tensor = lhs + .par_iter() + .zip(rhs) + .map(|(o, r)| o.clone() * r) + .collect(); + res.reshape(&broadcasted_shape).unwrap(); + res + }; - Ok(lhs) + Ok(res) } } -impl + std::marker::Send + std::marker::Sync> Tensor { +impl + std::marker::Send + std::marker::Sync + IntoI64> Tensor { /// Elementwise raise a tensor to the nth power. /// # Arguments /// @@ -1661,4 +1908,66 @@ mod tests { let b = Tensor::::new(Some(&[1, 4]), &[2, 1]).unwrap(); assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b); } + + #[test] + #[cfg(feature = "metal")] + fn tensor_metal_int() { + let a = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); + let b = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); + let c = metal_tensor_op(&a, &b, "add"); + assert_eq!(c, Tensor::new(Some(&[2, 4, 6, 8]), &[2, 2]).unwrap()); + + let c = metal_tensor_op(&a, &b, "sub"); + assert_eq!(c, Tensor::new(Some(&[0, 0, 0, 0]), &[2, 2]).unwrap()); + + let c = metal_tensor_op(&a, &b, "mul"); + assert_eq!(c, Tensor::new(Some(&[1, 4, 9, 16]), &[2, 2]).unwrap()); + } + + #[test] + #[cfg(feature = "metal")] + fn tensor_metal_felt() { + use halo2curves::bn256::Fr; + + let a = Tensor::::new( + Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]), + &[2, 2], + ) + .unwrap(); + let b = Tensor::::new( + Some(&[Fr::from(1), Fr::from(2), Fr::from(3), Fr::from(4)]), + &[2, 2], + ) + .unwrap(); + + let c = metal_tensor_op(&a, &b, "add"); + assert_eq!( + c, + Tensor::::new( + Some(&[Fr::from(2), Fr::from(4), Fr::from(6), Fr::from(8)]), + &[2, 2], + ) + .unwrap() + ); + + let c = metal_tensor_op(&a, &b, "sub"); + assert_eq!( + c, + Tensor::::new( + Some(&[Fr::from(0), Fr::from(0), Fr::from(0), Fr::from(0)]), + &[2, 2], + ) + .unwrap() + ); + + let c = metal_tensor_op(&a, &b, "mul"); + assert_eq!( + c, + Tensor::::new( + Some(&[Fr::from(1), Fr::from(4), Fr::from(9), Fr::from(16)]), + &[2, 2], + ) + .unwrap() + ); + } } diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 4b757e6d2..0b44054dc 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1,4 +1,4 @@ -use super::TensorError; +use super::{IntoI64, TensorError}; use crate::tensor::{Tensor, TensorType}; use itertools::Itertools; use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator}; @@ -13,88 +13,88 @@ pub use std::ops::{Add, Mul, Neg, Sub}; /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::trilu; -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[1, 3, 2], /// ).unwrap(); /// let result = trilu(&a, 1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 2, 0, 0, 0, 0]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 2, 0, 0, 0, 0]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 0, 0]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 0, 0]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[1, 2, 3], /// ).unwrap(); /// let result = trilu(&a, 1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 2, 3, 0, 0, 6]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 2, 3, 0, 0, 6]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 5, 6]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 5, 6]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 0, 5, 6]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 0, 5, 6]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap(); /// assert_eq!(result, expected); /// -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]), /// &[1, 3, 3], /// ).unwrap(); /// let result = trilu(&a, 1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 2, 3, 0, 0, 6, 0, 0, 0]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 2, 3, 0, 0, 6, 0, 0, 0]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 5, 6, 7, 8, 9]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 0, 4, 5, 6, 7, 8, 9]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 0, 5, 6, 0, 0, 9]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 0, 5, 6, 0, 0, 9]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, 0, false).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, true).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = trilu(&a, -1, false).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn trilu( @@ -148,40 +148,40 @@ pub fn trilu( /// ``` /// /// -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap(); /// let result = resize(&a, &[1, 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]), &[2, 6]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]), &[2, 6]).unwrap(); /// assert_eq!(result, expected); /// /// -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap(); /// let result = resize(&a, &[2, 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6]), &[4, 6]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6]), &[4, 6]).unwrap(); /// assert_eq!(result, expected); /// /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::resize; -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4]), /// &[2, 2], /// ).unwrap(); /// let result = resize(&a, &[2, 2]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4]), &[4, 4]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4]), &[4, 4]).unwrap(); /// assert_eq!(result, expected); /// /// -/// let a = Tensor::::new( +/// let a = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[3, 2], /// ).unwrap(); /// let result = resize(&a, &[2, 3]).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 5, 5, 5, 6, 6, 6]), &[6, 6]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 5, 5, 5, 6, 6, 6]), &[6, 6]).unwrap(); /// assert_eq!(result, expected); /// /// @@ -227,31 +227,31 @@ pub fn resize( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::add; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2, 3, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let result = add(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// // Now test 1D casting -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2]), /// &[1]).unwrap(); /// let result = add(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn add + std::marker::Send + std::marker::Sync>( +pub fn add + std::marker::Send + std::marker::Sync + IntoI64>( t: &[Tensor], ) -> Result, TensorError> { // calculate value of output @@ -273,32 +273,32 @@ pub fn add + std::marker::Send + std::marker::Sy /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::sub; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2, 3, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let result = sub(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// // Now test 1D sub -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2]), /// &[1], /// ).unwrap(); /// let result = sub(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn sub + std::marker::Send + std::marker::Sync>( +pub fn sub + std::marker::Send + std::marker::Sync + IntoI64>( t: &[Tensor], ) -> Result, TensorError> { // calculate value of output @@ -319,31 +319,31 @@ pub fn sub + std::marker::Send + std::marker::Sy /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::mult; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2, 3, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let result = mult(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// // Now test 1D mult -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); -/// let k = Tensor::::new( +/// let k = Tensor::::new( /// Some(&[2]), /// &[1]).unwrap(); /// let result = mult(&[x, k]).unwrap(); -/// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn mult + std::marker::Send + std::marker::Sync>( +pub fn mult + std::marker::Send + std::marker::Sync + IntoI64>( t: &[Tensor], ) -> Result, TensorError> { // calculate value of output @@ -366,24 +366,24 @@ pub fn mult + std::marker::Send + std::marker::S /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::downsample; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap(); /// let result = downsample(&x, 0, 1, 1).unwrap(); -/// let expected = Tensor::::new(Some(&[4, 5, 6]), &[1, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[4, 5, 6]), &[1, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = downsample(&x, 1, 2, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 3, 4, 6]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 3, 4, 6]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let result = downsample(&x, 1, 2, 1).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 5]), &[2, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 5]), &[2, 1]).unwrap(); /// assert_eq!(result, expected); /// /// let result = downsample(&x, 1, 2, 2).unwrap(); -/// let expected = Tensor::::new(Some(&[3, 6]), &[2, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[3, 6]), &[2, 1]).unwrap(); /// assert_eq!(result, expected); pub fn downsample( input: &Tensor, @@ -435,7 +435,7 @@ pub fn downsample( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::gather; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[2, 3], /// ).unwrap(); @@ -444,7 +444,7 @@ pub fn downsample( /// &[2], /// ).unwrap(); /// let result = gather(&x, &index, 1).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 4, 5]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 4, 5]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn gather( @@ -562,7 +562,7 @@ pub fn scatter( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::gather_elements; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4]), /// &[2, 2], /// ).unwrap(); @@ -571,7 +571,7 @@ pub fn scatter( /// &[2, 2], /// ).unwrap(); /// let result = gather_elements(&x, &index, 1).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 1, 4, 3]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 1, 4, 3]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn gather_elements( @@ -619,7 +619,7 @@ pub fn gather_elements( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::gather_nd; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[0, 1, 2, 3]), /// &[2, 2], /// ).unwrap(); @@ -628,7 +628,7 @@ pub fn gather_elements( /// &[2, 2], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 3]), &[2]).unwrap(); +/// let expected = Tensor::::new(Some(&[0, 3]), &[2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -636,10 +636,10 @@ pub fn gather_elements( /// &[2, 1], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 0, 1]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[0, 1, 2, 3, 4, 5, 6, 7]), /// &[2, 2, 2], /// ).unwrap(); @@ -648,7 +648,7 @@ pub fn gather_elements( /// &[2, 2], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -656,7 +656,7 @@ pub fn gather_elements( /// &[2, 1, 2], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 1, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -664,7 +664,7 @@ pub fn gather_elements( /// &[2, 1], /// ).unwrap(); /// let result = gather_nd(&x, &index, 1).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -672,7 +672,7 @@ pub fn gather_elements( /// &[2, 2, 3], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[2, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -680,7 +680,7 @@ pub fn gather_elements( /// &[2, 2, 2], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 3, 0, 1, 6, 7, 4, 5]), &[2, 2, 2]).unwrap(); /// assert_eq!(result, expected); /// /// let index = Tensor::::new( @@ -688,7 +688,7 @@ pub fn gather_elements( /// &[2, 3], /// ).unwrap(); /// let result = gather_nd(&x, &index, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 7]), &[2]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 7]), &[2]).unwrap(); /// assert_eq!(result, expected); /// pub fn gather_nd( @@ -799,7 +799,7 @@ pub fn gather_nd( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::scatter_nd; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6, 7, 8]), /// &[8], /// ).unwrap(); @@ -808,15 +808,15 @@ pub fn gather_nd( /// Some(&[4, 3, 1, 7]), /// &[4, 1], /// ).unwrap(); -/// let src = Tensor::::new( +/// let src = Tensor::::new( /// Some(&[9, 10, 11, 12]), /// &[4], /// ).unwrap(); /// let result = scatter_nd(&x, &index, &src).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 11, 3, 10, 9, 6, 7, 12]), &[8]).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, /// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, /// 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8, @@ -829,7 +829,7 @@ pub fn gather_nd( /// &[2, 1], /// ).unwrap(); /// -/// let src = Tensor::::new( +/// let src = Tensor::::new( /// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, /// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, /// ]), @@ -838,7 +838,7 @@ pub fn gather_nd( /// /// let result = scatter_nd(&x, &index, &src).unwrap(); /// -/// let expected = Tensor::::new( +/// let expected = Tensor::::new( /// Some(&[5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, /// 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, /// 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, @@ -847,7 +847,7 @@ pub fn gather_nd( /// ).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6, 7, 8]), /// &[2, 4], /// ).unwrap(); @@ -856,15 +856,15 @@ pub fn gather_nd( /// Some(&[0, 1]), /// &[2, 1], /// ).unwrap(); -/// let src = Tensor::::new( +/// let src = Tensor::::new( /// Some(&[9, 10]), /// &[2], /// ).unwrap(); /// let result = scatter_nd(&x, &index, &src).unwrap(); -/// let expected = Tensor::::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap(); +/// let expected = Tensor::::new(Some(&[9, 9, 9, 9, 10, 10, 10, 10]), &[2, 4]).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6, 7, 8]), /// &[2, 4], /// ).unwrap(); @@ -873,12 +873,12 @@ pub fn gather_nd( /// Some(&[0, 1]), /// &[1, 1, 2], /// ).unwrap(); -/// let src = Tensor::::new( +/// let src = Tensor::::new( /// Some(&[9]), /// &[1, 1], /// ).unwrap(); /// let result = scatter_nd(&x, &index, &src).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 9, 3, 4, 5, 6, 7, 8]), &[2, 4]).unwrap(); /// assert_eq!(result, expected); /// ```` /// @@ -927,12 +927,12 @@ pub fn scatter_nd( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::abs; -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[-2, 15, 2, -1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = abs(&x).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 0]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn abs + std::cmp::Ord + Neg>( @@ -953,14 +953,14 @@ pub fn abs + std::cmp::Ord + Neg>( /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::intercalate_values; /// -/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); +/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); /// let result = intercalate_values(&tensor, 0, 2, 1).unwrap(); /// -/// let expected = Tensor::::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// let result = intercalate_values(&expected, 0, 2, 0).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap(); /// /// assert_eq!(result, expected); /// @@ -1006,23 +1006,23 @@ pub fn intercalate_values( /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::one_hot; -/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); +/// let tensor = Tensor::::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap(); /// let result = one_hot(&tensor, 5, 2).unwrap(); -/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, +/// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 0, /// 0, 0, 1, 0, 0, /// 0, 0, 0, 1, 0, /// 0, 0, 0, 0, 1]), &[2, 2, 5]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn one_hot( - tensor: &Tensor, + tensor: &Tensor, num_classes: usize, axis: usize, -) -> Result, TensorError> { +) -> Result, TensorError> { let mut output_dims = tensor.dims().to_vec(); output_dims.insert(axis, num_classes); - let mut output: Tensor = Tensor::new(None, &output_dims)?; + let mut output: Tensor = Tensor::new(None, &output_dims)?; let cartesian_coord = output .dims() @@ -1071,18 +1071,18 @@ pub fn one_hot( /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::pad; /// -/// let x = Tensor::::new( +/// let x = Tensor::::new( /// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), /// &[1, 1, 3, 3], /// ).unwrap(); -/// let result = pad::(&x, vec![(0, 0), (0, 0), (1, 1), (1, 1)], 0).unwrap(); -/// let expected = Tensor::::new( +/// let result = pad::(&x, vec![(0, 0), (0, 0), (1, 1), (1, 1)], 0).unwrap(); +/// let expected = Tensor::::new( /// Some(&[0, 0, 0, 0, 0, 0, 5, 2, 3, 0, 0, 0, 4, -1, 0, 0, 3, 1, 6, 0, 0, 0, 0, 0, 0]), /// &[1, 1, 5, 5], /// ).unwrap(); /// assert_eq!(result, expected); /// -/// let result = pad::(&x, vec![(1, 1), (1, 1)], 2).unwrap(); +/// let result = pad::(&x, vec![(1, 1), (1, 1)], 2).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn pad( @@ -1132,33 +1132,33 @@ pub fn pad( /// // tested against pytorch outputs for reference :) /// /// // 1D example -/// let x = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); -/// let y = Tensor::::new(Some(&[4, 5, 6]), &[3]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); +/// let y = Tensor::::new(Some(&[4, 5, 6]), &[3]).unwrap(); /// let result = concat(&[&x, &y], 0).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); /// assert_eq!(result, expected); /// /// // 2D example -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); -/// let y = Tensor::::new(Some(&[7, 8, 9]), &[3, 1]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); +/// let y = Tensor::::new(Some(&[7, 8, 9]), &[3, 1]).unwrap(); /// let result = concat(&[&x, &y], 1).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 7, 3, 4, 8, 5, 6, 9]), &[3, 3]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 7, 3, 4, 8, 5, 6, 9]), &[3, 3]).unwrap(); /// assert_eq!(result, expected); /// /// /// 4D example -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), &[2, 2, 2, 2]).unwrap(); -/// let y = Tensor::::new(Some(&[17, 18, 19, 20, 21, 22, 23, 14]), &[2, 2, 1, 2]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), &[2, 2, 2, 2]).unwrap(); +/// let y = Tensor::::new(Some(&[17, 18, 19, 20, 21, 22, 23, 14]), &[2, 2, 1, 2]).unwrap(); /// let result = concat(&[&x, &y], 2).unwrap(); -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 17, 18, 5, 6, 7, 8, 19, 20, 9, 10, 11, 12, 21, 22, 13, 14, 15, 16, 23, 14]), &[2, 2, 3, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 17, 18, 5, 6, 7, 8, 19, 20, 9, 10, 11, 12, 21, 22, 13, 14, 15, 16, 23, 14]), &[2, 2, 3, 2]).unwrap(); /// assert_eq!(result, expected); /// /// /// // 5D example -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), &[8, 1, 1, 1, 2]).unwrap(); -/// let y = Tensor::::new(Some(&[17, 18, 19, 20, 21, 22, 23, 14]), &[4, 1, 1, 1, 2]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), &[8, 1, 1, 1, 2]).unwrap(); +/// let y = Tensor::::new(Some(&[17, 18, 19, 20, 21, 22, 23, 14]), &[4, 1, 1, 1, 2]).unwrap(); /// let result = concat(&[&x, &y], 0).unwrap(); /// -/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 14]), &[12, 1, 1, 1, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 14]), &[12, 1, 1, 1, 2]).unwrap(); /// assert_eq!(result, expected); /// /// ``` @@ -1231,19 +1231,19 @@ pub fn concat( /// // tested against pytorch output /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::slice; -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); /// let result = slice(&x, &0, &1, &2).unwrap(); -/// let expected = Tensor::::new(Some(&[3, 4]), &[1, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[3, 4]), &[1, 2]).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 2]).unwrap(); /// let result = slice(&x, &1, &1, &2).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 4, 6]), &[3, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 4, 6]), &[3, 1]).unwrap(); /// assert_eq!(result, expected); /// -/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 2, 3]).unwrap(); +/// let x = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 2, 3]).unwrap(); /// let result = slice(&x, &2, &1, &2).unwrap(); -/// let expected = Tensor::::new(Some(&[2, 5, 8, 11]), &[2, 2, 1]).unwrap(); +/// let expected = Tensor::::new(Some(&[2, 5, 8, 11]), &[2, 2, 1]).unwrap(); /// assert_eq!(result, expected); /// ``` /// @@ -1285,19 +1285,19 @@ pub mod nonlinearities { /// use ezkl::tensor::Tensor; /// /// use ezkl::tensor::ops::nonlinearities::ceil; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[3, 2], /// ).unwrap(); /// let result = ceil(&x, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn ceil(a: &Tensor, scale: f64) -> Tensor { + pub fn ceil(a: &Tensor, scale: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale; let rounded = kix.ceil() * scale; - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1310,19 +1310,19 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::floor; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[3, 2], /// ).unwrap(); /// let result = floor(&x, 2.0); - /// let expected = Tensor::::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn floor(a: &Tensor, scale: f64) -> Tensor { + pub fn floor(a: &Tensor, scale: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale; let rounded = kix.floor() * scale; - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1335,19 +1335,19 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::round; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[3, 2], /// ).unwrap(); /// let result = round(&x, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn round(a: &Tensor, scale: f64) -> Tensor { + pub fn round(a: &Tensor, scale: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale; let rounded = kix.round() * scale; - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1360,19 +1360,19 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::round_half_to_even; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[1, 2, 3, 4, 5, 6]), /// &[3, 2], /// ).unwrap(); /// let result = round_half_to_even(&x, 2.0); - /// let expected = Tensor::::new(Some(&[0, 2, 4, 4, 4, 6]), &[3, 2]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 2, 4, 4, 4, 6]), &[3, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn round_half_to_even(a: &Tensor, scale: f64) -> Tensor { + pub fn round_half_to_even(a: &Tensor, scale: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale; let rounded = kix.round_ties_even() * scale; - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1385,20 +1385,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::pow; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = pow(&x, 1.0, 2.0); - /// let expected = Tensor::::new(Some(&[4, 225, 4, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 225, 4, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn pow(a: &Tensor, scale_input: f64, power: f64) -> Tensor { + pub fn pow(a: &Tensor, scale_input: f64, power: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let kix = scale_input * (kix).powf(power); let rounded = kix.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1410,12 +1410,12 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::kronecker_delta; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = kronecker_delta(&x); - /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn kronecker_delta( @@ -1441,37 +1441,37 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::sigmoid; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sigmoid(&x, 1.0); - /// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 1, 1, 1]), &[2, 3]).unwrap(); /// /// assert_eq!(result, expected); - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[65536]), /// &[1], /// ).unwrap(); /// let result = sigmoid(&x, 65536.0); - /// let expected = Tensor::::new(Some(&[47911]), &[1]).unwrap(); + /// let expected = Tensor::::new(Some(&[47911]), &[1]).unwrap(); /// assert_eq!(result, expected); /// /// /// assert_eq!(result, expected); - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[256]), /// &[1], /// ).unwrap(); /// let result = sigmoid(&x, 256.0); - /// let expected = Tensor::::new(Some(&[187]), &[1]).unwrap(); + /// let expected = Tensor::::new(Some(&[187]), &[1]).unwrap(); /// /// ``` - pub fn sigmoid(a: &Tensor, scale_input: f64) -> Tensor { + pub fn sigmoid(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input / (1.0 + (-kix).exp()); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1491,17 +1491,17 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::hardswish; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[-12, -3, 2, 1, 1, 15]), /// &[2, 3], /// ).unwrap(); /// let result = hardswish(&x, 1.0); - /// let expected = Tensor::::new(Some(&[0, 0, 2, 1, 1, 15]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 0, 2, 1, 1, 15]), &[2, 3]).unwrap(); /// /// assert_eq!(result, expected); /// /// ``` - pub fn hardswish(a: &Tensor, scale_input: f64) -> Tensor { + pub fn hardswish(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let res = if kix <= -3.0 { @@ -1512,7 +1512,7 @@ pub mod nonlinearities { kix * (kix + 3.0) / 6.0 }; let rounded = (res * scale_input).round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1527,31 +1527,31 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::exp; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = exp(&x, 1.0); - /// let expected = Tensor::::new(Some(&[7, 3269017, 7, 3, 3, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[7, 3269017, 7, 3, 3, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[37, 12, 41]), /// &[3], /// ).unwrap(); /// let result = exp(&x, 512.0); /// - /// let expected = Tensor::::new(Some(&[550, 524, 555]), &[3]).unwrap(); + /// let expected = Tensor::::new(Some(&[550, 524, 555]), &[3]).unwrap(); /// /// assert_eq!(result, expected); /// ``` - pub fn exp(a: &Tensor, scale_input: f64) -> Tensor { + pub fn exp(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.exp(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1566,31 +1566,31 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::ln; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 3000]), /// &[2, 3], /// ).unwrap(); /// let result = ln(&x, 1.0); - /// let expected = Tensor::::new(Some(&[1, 3, 1, 0, 0, 8]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 3, 1, 0, 0, 8]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// /// - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[37, 12, 41]), /// &[3], /// ).unwrap(); /// let result = ln(&x, 512.0); /// - /// let expected = Tensor::::new(Some(&[-1345, -1922, -1293]), &[3]).unwrap(); + /// let expected = Tensor::::new(Some(&[-1345, -1922, -1293]), &[3]).unwrap(); /// /// assert_eq!(result, expected); /// ``` - pub fn ln(a: &Tensor, scale_input: f64) -> Tensor { + pub fn ln(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.ln(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1602,15 +1602,15 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::sign; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[-2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sign(&x); - /// let expected = Tensor::::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn sign(a: &Tensor) -> Tensor { + pub fn sign(a: &Tensor) -> Tensor { a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum())) .unwrap() } @@ -1625,20 +1625,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::sqrt; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sqrt(&x, 1.0); - /// let expected = Tensor::::new(Some(&[2, 5, 3, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 5, 3, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn sqrt(a: &Tensor, scale_input: f64) -> Tensor { + pub fn sqrt(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.sqrt(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1653,20 +1653,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::rsqrt; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let result = rsqrt(&x, 1.0); - /// let expected = Tensor::::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn rsqrt(a: &Tensor, scale_input: f64) -> Tensor { + pub fn rsqrt(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input / (kix.sqrt() + f64::EPSILON); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1680,20 +1680,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::cos; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = cos(&x, 2.0); - /// let expected = Tensor::::new(Some(& [-1, 2, -1, 2, 2, 2]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(& [-1, 2, -1, 2, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn cos(a: &Tensor, scale_input: f64) -> Tensor { + pub fn cos(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.cos(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1707,20 +1707,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::acos; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = acos(&x, 1.0); - /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 2]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn acos(a: &Tensor, scale_input: f64) -> Tensor { + pub fn acos(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.acos(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1734,20 +1734,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::cosh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = cosh(&x, 1.0); - /// let expected = Tensor::::new(Some(&[27, 36002449669, 1490, 2, 2, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[27, 36002449669, 1490, 2, 2, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn cosh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn cosh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.cosh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1761,20 +1761,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::acosh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = acosh(&x, 1.0); - /// let expected = Tensor::::new(Some(& [2, 4, 3, 0, 0, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(& [2, 4, 3, 0, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn acosh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn acosh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.acosh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1788,20 +1788,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::sin; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sin(&x, 128.0); - /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn sin(a: &Tensor, scale_input: f64) -> Tensor { + pub fn sin(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.sin(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1815,20 +1815,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::asin; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = asin(&x, 128.0); - /// let expected = Tensor::::new(Some(& [4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(& [4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn asin(a: &Tensor, scale_input: f64) -> Tensor { + pub fn asin(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.asin(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1842,20 +1842,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::sinh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sinh(&x, 2.0); - /// let expected = Tensor::::new(Some(&[7, 268337, 55, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[7, 268337, 55, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn sinh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn sinh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.sinh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1869,20 +1869,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::asinh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = asinh(&x, 128.0); - /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn asinh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn asinh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.asinh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1896,20 +1896,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::tan; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = tan(&x, 64.0); - /// let expected = Tensor::::new(Some(&[4, 26, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 26, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn tan(a: &Tensor, scale_input: f64) -> Tensor { + pub fn tan(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.tan(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1923,20 +1923,20 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::atan; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = atan(&x, 128.0); - /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn atan(a: &Tensor, scale_input: f64) -> Tensor { + pub fn atan(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.atan(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1951,21 +1951,21 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::tanh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = tanh(&x, 128.0); - /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 25, 8, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn tanh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn tanh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.tanh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -1980,21 +1980,21 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::atanh; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[4, 25, 8, 2, 2, 0]), /// &[2, 3], /// ).unwrap(); /// let result = atanh(&x, 32.0); - /// let expected = Tensor::::new(Some(&[4, 34, 8, 2, 2, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[4, 34, 8, 2, 2, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn atanh(a: &Tensor, scale_input: f64) -> Tensor { + pub fn atanh(a: &Tensor, scale_input: f64) -> Tensor { a.par_enum_map(|_, a_i| { let kix = (a_i as f64) / scale_input; let fout = scale_input * kix.atanh(); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -2009,15 +2009,15 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::erffunc; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[5, 28, 9, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = erffunc(&x, 128.0); - /// let expected = Tensor::::new(Some(&[6, 31, 10, 1, 1, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[6, 31, 10, 1, 1, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn erffunc(a: &Tensor, scale_input: f64) -> Tensor { + pub fn erffunc(a: &Tensor, scale_input: f64) -> Tensor { const NCOEF: usize = 28; const COF: [f64; 28] = [ -1.3026537197817094, @@ -2078,7 +2078,7 @@ pub mod nonlinearities { let kix = (a_i as f64) / scale_input; let fout = scale_input * erf(kix); let rounded = fout.round(); - Ok::<_, TensorError>(rounded as i128) + Ok::<_, TensorError>(rounded as i64) }) .unwrap() } @@ -2093,22 +2093,22 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::leakyrelu; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, -5]), /// &[2, 3], /// ).unwrap(); /// let result = leakyrelu(&x, 0.1); - /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn leakyrelu(a: &Tensor, slope: f64) -> Tensor { + pub fn leakyrelu(a: &Tensor, slope: f64) -> Tensor { a.par_enum_map(|_, a_i| { let rounded = if a_i < 0 { let d_inv_x = (slope) * (a_i as f64); - d_inv_x.round() as i128 + d_inv_x.round() as i64 } else { let d_inv_x = a_i as f64; - d_inv_x.round() as i128 + d_inv_x.round() as i64 }; Ok::<_, TensorError>(rounded) }) @@ -2123,22 +2123,22 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::max; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, -5]), /// &[2, 3], /// ).unwrap(); /// let result = max(&x, 1.0, 1.0); - /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn max(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { + pub fn max(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { // calculate value of output a.par_enum_map(|_, a_i| { let d_inv_x = (a_i as f64) / scale_input; let rounded = if d_inv_x <= threshold { - (threshold * scale_input).round() as i128 + (threshold * scale_input).round() as i64 } else { - (d_inv_x * scale_input).round() as i128 + (d_inv_x * scale_input).round() as i64 }; Ok::<_, TensorError>(rounded) }) @@ -2153,22 +2153,22 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::min; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, -5]), /// &[2, 3], /// ).unwrap(); /// let result = min(&x, 1.0, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn min(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { + pub fn min(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { // calculate value of output a.par_enum_map(|_, a_i| { let d_inv_x = (a_i as f64) / scale_input; let rounded = if d_inv_x >= threshold { - (threshold * scale_input).round() as i128 + (threshold * scale_input).round() as i64 } else { - (d_inv_x * scale_input).round() as i128 + (d_inv_x * scale_input).round() as i64 }; Ok::<_, TensorError>(rounded) }) @@ -2184,19 +2184,19 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::const_div; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2.0; /// let result = const_div(&x, k); - /// let expected = Tensor::::new(Some(&[1, 1, 1, 4, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 4, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn const_div(a: &Tensor, denom: f64) -> Tensor { + pub fn const_div(a: &Tensor, denom: f64) -> Tensor { a.par_enum_map(|_, a_i| { let d_inv_x = (a_i as f64) / (denom); - Ok::<_, TensorError>(d_inv_x.round() as i128) + Ok::<_, TensorError>(d_inv_x.round() as i64) }) .unwrap() } @@ -2210,21 +2210,21 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::recip; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2_f64; /// let result = recip(&x, 1.0, k); - /// let expected = Tensor::::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn recip(a: &Tensor, input_scale: f64, out_scale: f64) -> Tensor { + pub fn recip(a: &Tensor, input_scale: f64, out_scale: f64) -> Tensor { a.par_enum_map(|_, a_i| { let rescaled = (a_i as f64) / input_scale; let denom = (1_f64) / (rescaled + f64::EPSILON); let d_inv_x = out_scale * denom; - Ok::<_, TensorError>(d_inv_x.round() as i128) + Ok::<_, TensorError>(d_inv_x.round() as i64) }) .unwrap() } @@ -2238,17 +2238,17 @@ pub mod nonlinearities { /// use ezkl::tensor::ops::nonlinearities::zero_recip; /// let k = 2_f64; /// let result = zero_recip(1.0); - /// let expected = Tensor::::new(Some(&[4503599627370496]), &[1]).unwrap(); + /// let expected = Tensor::::new(Some(&[4503599627370496]), &[1]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn zero_recip(out_scale: f64) -> Tensor { - let a = Tensor::::new(Some(&[0]), &[1]).unwrap(); + pub fn zero_recip(out_scale: f64) -> Tensor { + let a = Tensor::::new(Some(&[0]), &[1]).unwrap(); a.par_enum_map(|_, a_i| { let rescaled = a_i as f64; let denom = (1_f64) / (rescaled + f64::EPSILON); let d_inv_x = out_scale * denom; - Ok::<_, TensorError>(d_inv_x.round() as i128) + Ok::<_, TensorError>(d_inv_x.round() as i64) }) .unwrap() } @@ -2262,17 +2262,17 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::greater_than; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2.0; /// let result = greater_than(&x, k); - /// let expected = Tensor::::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn greater_than(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) > 0_f64))) + pub fn greater_than(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i64::from((a_i as f64 - b) > 0_f64))) .unwrap() } @@ -2285,17 +2285,17 @@ pub mod nonlinearities { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::greater_than_equal; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2.0; /// let result = greater_than_equal(&x, k); - /// let expected = Tensor::::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn greater_than_equal(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) >= 0_f64))) + pub fn greater_than_equal(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i64::from((a_i as f64 - b) >= 0_f64))) .unwrap() } @@ -2308,18 +2308,18 @@ pub mod nonlinearities { /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::less_than; /// - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2.0; /// /// let result = less_than(&x, k); - /// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn less_than(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) < 0_f64))) + pub fn less_than(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i64::from((a_i as f64 - b) < 0_f64))) .unwrap() } @@ -2332,18 +2332,18 @@ pub mod nonlinearities { /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::nonlinearities::less_than_equal; /// - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 1, 2, 7, 1, 1]), /// &[2, 3], /// ).unwrap(); /// let k = 2.0; /// /// let result = less_than_equal(&x, k); - /// let expected = Tensor::::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` - pub fn less_than_equal(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i128::from((a_i as f64 - b) <= 0_f64))) + pub fn less_than_equal(a: &Tensor, b: f64) -> Tensor { + a.par_enum_map(|_, a_i| Ok::<_, TensorError>(i64::from((a_i as f64 - b) <= 0_f64))) .unwrap() } } @@ -2361,15 +2361,15 @@ pub mod accumulated { /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::accumulated::dot; /// - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[5, 2]), /// &[2], /// ).unwrap(); - /// let y = Tensor::::new( + /// let y = Tensor::::new( /// Some(&[5, 5]), /// &[2], /// ).unwrap(); - /// let expected = Tensor::::new( + /// let expected = Tensor::::new( /// Some(&[25, 35]), /// &[2], /// ).unwrap(); @@ -2409,12 +2409,12 @@ pub mod accumulated { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::accumulated::sum; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = sum(&x, 1).unwrap(); - /// let expected = Tensor::::new( + /// let expected = Tensor::::new( /// Some(&[2, 17, 19, 20, 21, 21]), /// &[6], /// ).unwrap(); @@ -2446,12 +2446,12 @@ pub mod accumulated { /// ``` /// use ezkl::tensor::Tensor; /// use ezkl::tensor::ops::accumulated::prod; - /// let x = Tensor::::new( + /// let x = Tensor::::new( /// Some(&[2, 15, 2, 1, 1, 0]), /// &[2, 3], /// ).unwrap(); /// let result = prod(&x, 1).unwrap(); - /// let expected = Tensor::::new( + /// let expected = Tensor::::new( /// Some(&[2, 30, 60, 60, 60, 0]), /// &[6], /// ).unwrap(); diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 71b7123de..7081c48d7 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -316,9 +316,9 @@ impl From>> f } impl ValTensor { - /// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128]. - pub fn from_i128_tensor(t: Tensor) -> ValTensor { - let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x)))); + /// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64]. + pub fn from_i64_tensor(t: Tensor) -> ValTensor { + let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x)))); inner.into() } @@ -521,9 +521,9 @@ impl ValTensor { } /// Calls `int_evals` on the inner tensor. - pub fn get_int_evals(&self) -> Result, Box> { + pub fn get_int_evals(&self) -> Result, Box> { // finally convert to vector of integers - let mut integer_evals: Vec = vec![]; + let mut integer_evals: Vec = vec![]; match self { ValTensor::Value { inner: v, dims: _, .. @@ -531,25 +531,25 @@ impl ValTensor { // we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want) let _ = v.map(|vaf| match vaf { ValType::Value(v) => v.map(|f| { - integer_evals.push(crate::fieldutils::felt_to_i128(f)); + integer_evals.push(crate::fieldutils::felt_to_i64(f)); }), ValType::AssignedValue(v) => v.map(|f| { - integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate())); + integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate())); }), ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => { v.value_field().map(|f| { - integer_evals.push(crate::fieldutils::felt_to_i128(f.evaluate())); + integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate())); }) } ValType::Constant(v) => { - integer_evals.push(crate::fieldutils::felt_to_i128(v)); + integer_evals.push(crate::fieldutils::felt_to_i64(v)); Value::unknown() } }); } _ => return Err(Box::new(TensorError::WrongMethod)), }; - let mut tensor: Tensor = integer_evals.into_iter().into(); + let mut tensor: Tensor = integer_evals.into_iter().into(); match tensor.reshape(self.dims()) { _ => {} }; diff --git a/src/wasm.rs b/src/wasm.rs index d26b3b080..4a3edfa16 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -2,8 +2,8 @@ use crate::circuit::modules::polycommit::PolyCommitChip; use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}; use crate::circuit::modules::poseidon::PoseidonChip; use crate::circuit::modules::Module; -use crate::fieldutils::felt_to_i128; -use crate::fieldutils::i128_to_felt; +use crate::fieldutils::felt_to_i64; +use crate::fieldutils::i64_to_felt; use crate::graph::modules::POSEIDON_LEN_GRAPH; use crate::graph::quantize_float; use crate::graph::scale_to_multiplier; @@ -113,7 +113,7 @@ pub fn feltToInt( let felt: Fr = serde_json::from_slice(&array[..]) .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; Ok(wasm_bindgen::Clamped( - serde_json::to_vec(&felt_to_i128(felt)) + serde_json::to_vec(&felt_to_i64(felt)) .map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?, )) } @@ -127,7 +127,7 @@ pub fn feltToFloat( ) -> Result { let felt: Fr = serde_json::from_slice(&array[..]) .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; - let int_rep = felt_to_i128(felt); + let int_rep = felt_to_i64(felt); let multiplier = scale_to_multiplier(scale); Ok(int_rep as f64 / multiplier) } @@ -141,7 +141,7 @@ pub fn floatToFelt( ) -> Result>, JsError> { let int_rep = quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?; - let felt = i128_to_felt(int_rep); + let felt = i64_to_felt(int_rep); let vec = crate::pfsys::field_to_string::(&felt); Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err( |e| JsError::new(&format!("Failed to serialize a float to felt{}", e)), @@ -275,7 +275,7 @@ pub fn genWitness( .map_err(|e| JsError::new(&format!("{}", e)))?; let witness = circuit - .forward::>(&mut input, None, None, false) + .forward::>(&mut input, None, None, false, false) .map_err(|e| JsError::new(&format!("{}", e)))?; serde_json::to_vec(&witness) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index d0feda7b4..ceda0c674 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -3,7 +3,7 @@ mod native_tests { use ezkl::circuit::Tolerance; - use ezkl::fieldutils::{felt_to_i128, i128_to_felt}; + use ezkl::fieldutils::{felt_to_i64, i64_to_felt}; // use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD; use ezkl::graph::input::{FileSource, FileSourceInner, GraphData}; use ezkl::graph::{DataSource, GraphSettings, GraphWitness}; @@ -908,7 +908,7 @@ mod native_tests { prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single", Commitments::KZG, 2); #[cfg(not(feature = "icicle"))] run_js_tests(path, test.to_string(), "testWasm", false); - // test_dir.close().unwrap(); + test_dir.close().unwrap(); } #(#[test_case(WASM_TESTS[N])])* @@ -921,7 +921,7 @@ mod native_tests { prove_and_verify(path, test.to_string(), "safe", "hashed", "private", "public", 1, None, true, "single", Commitments::KZG, 2); #[cfg(not(feature = "icicle"))] run_js_tests(path, test.to_string(), "testWasm", false); - // test_dir.close().unwrap(); + test_dir.close().unwrap(); } #(#[test_case(WASM_TESTS[N])])* @@ -1373,7 +1373,7 @@ mod native_tests { let witness = witness.clone(); let outputs = witness.outputs.clone(); - // get values as i128 + // get values as i64 let output_perturbed_safe: Vec> = outputs .iter() .map(|sv| { @@ -1383,10 +1383,10 @@ mod native_tests { let perturbation = if v == &halo2curves::bn256::Fr::zero() { halo2curves::bn256::Fr::zero() } else { - i128_to_felt( - (felt_to_i128(*v) as f32 + i64_to_felt( + (felt_to_i64(*v) as f32 * (rand::thread_rng().gen_range(-0.01..0.01) * tolerance)) - as i128, + as i64, ) }; @@ -1396,7 +1396,7 @@ mod native_tests { }) .collect::>(); - // get values as i128 + // get values as i64 let output_perturbed_bad: Vec> = outputs .iter() .map(|sv| { @@ -1406,10 +1406,10 @@ mod native_tests { let perturbation = if v == &halo2curves::bn256::Fr::zero() { halo2curves::bn256::Fr::from(2) } else { - i128_to_felt( - (felt_to_i128(*v) as f32 + i64_to_felt( + (felt_to_i64(*v) as f32 * (rand::thread_rng().gen_range(0.02..0.1) * tolerance)) - as i128, + as i64, ) }; *v + perturbation diff --git a/tests/wasm.rs b/tests/wasm.rs index d06b9e70c..8e2e9541e 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -150,10 +150,10 @@ mod wasm32 { .unwrap(); assert_eq!(floating_point, (i as f64) / 4.0); - let integer: i128 = + let integer: i64 = serde_json::from_slice(&feltToInt(clamped.clone()).map_err(|_| "failed").unwrap()) .unwrap(); - assert_eq!(integer, i as i128); + assert_eq!(integer, i as i64); let hex_string = format!("{:?}", field_element.clone()); let returned_string: String = feltToBigEndian(clamped.clone()) diff --git a/tests/wasm/model.compiled b/tests/wasm/model.compiled index d36215c15..9e6eeeeed 100644 Binary files a/tests/wasm/model.compiled and b/tests/wasm/model.compiled differ