diff --git a/.cargo/config b/.cargo/config.toml similarity index 100% rename from .cargo/config rename to .cargo/config.toml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index f97f7fd65..5787c7125 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -688,3 +688,66 @@ jobs: run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_ # - name: Reusable verifier tutorial # run: source .env/bin/activate; cargo nextest run py_tests::tests::reusable_ + + ios-integration-tests: + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-2024-07-18 + override: true + components: rustfmt, clippy + - uses: baptiste0928/cargo-install@v1 + with: + crate: cargo-nextest + locked: true + - name: Run ios tests + run: CARGO_BUILD_TARGET=aarch64-apple-darwin RUSTUP_TOOLCHAIN=nightly-2024-07-18-aarch64-apple-darwin cargo test --test ios_integration_tests --features ios-bindings-test --no-default-features + + swift-package-tests: + runs-on: macos-latest + needs: [ios-integration-tests] + + steps: + - uses: actions/checkout@v4 + - uses: actions-rs/toolchain@v1 + with: + toolchain: nightly-2024-07-18 + override: true + components: rustfmt, clippy + - name: Build EzklCoreBindings + run: CONFIGURATION=debug cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features + + - name: Clone ezkl-swift- repository + run: | + git clone https://github.com/zkonduit/ezkl-swift-package.git + + - name: Copy EzklCoreBindings + run: | + rm -rf ezkl-swift-package/Sources/EzklCoreBindings + cp -r build/EzklCoreBindings ezkl-swift-package/Sources/ + + - name: Set up Xcode environment + run: | + sudo xcode-select -s /Applications/Xcode.app/Contents/Developer + sudo xcodebuild -license accept + + - name: Run Package Tests + run: | + cd ezkl-swift-package + xcodebuild test \ + -scheme EzklPackage \ + -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \ + -resultBundlePath ../testResults + + - name: Run Example App Tests + run: | + cd ezkl-swift-package/Example + xcodebuild test \ + -project Example.xcodeproj \ + -scheme EzklApp \ + -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \ + -parallel-testing-enabled NO \ + -resultBundlePath ../../exampleTestResults \ + -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder \ No newline at end of file diff --git a/.github/workflows/update-ios-package.yml b/.github/workflows/update-ios-package.yml new file mode 100644 index 000000000..8ae294221 --- /dev/null +++ b/.github/workflows/update-ios-package.yml @@ -0,0 +1,75 @@ +name: Build and Publish EZKL iOS SPM package + +on: + workflow_dispatch: + inputs: + tag: + description: "The tag to release" + required: true + push: + tags: + - "*" + +jobs: + build-and-update: + runs-on: macos-latest + + steps: + - name: Checkout EZKL + uses: actions/checkout@v3 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + override: true + + - name: Build EzklCoreBindings + run: CONFIGURATION=release cargo run --bin ios_gen_bindings --features "ios-bindings uuid camino uniffi_bindgen" --no-default-features + + - name: Clone ezkl-swift-package repository + run: | + git clone https://github.com/zkonduit/ezkl-swift-package.git + + - name: Copy EzklCoreBindings + run: | + rm -rf ezkl-swift-package/Sources/EzklCoreBindings + cp -r build/EzklCoreBindings ezkl-swift-package/Sources/ + + - name: Set up Xcode environment + run: | + sudo xcode-select -s /Applications/Xcode.app/Contents/Developer + sudo xcodebuild -license accept + + - name: Run Package Tests + run: | + cd ezkl-swift-package + xcodebuild test \ + -scheme EzklPackage \ + -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \ + -resultBundlePath ../testResults + + - name: Run Example App Tests + run: | + cd ezkl-swift-package/Example + xcodebuild test \ + -project Example.xcodeproj \ + -scheme EzklApp \ + -destination 'platform=iOS Simulator,name=iPhone 15 Pro,OS=17.5' \ + -parallel-testing-enabled NO \ + -resultBundlePath ../../exampleTestResults \ + -skip-testing:EzklAppUITests/EzklAppUITests/testButtonClicksInOrder + + - name: Commit and Push Changes to feat/ezkl-direct-integration + run: | + cd ezkl-swift-package + git config user.name "GitHub Action" + git config user.email "action@github.com" + git add Sources/EzklCoreBindings + git commit -m "Automatically updated EzklCoreBindings for EZKL" + git tag ${{ github.event.inputs.tag }} + git remote set-url origin https://zkonduit:${EZKL_PORTER_TOKEN}@github.com/zkonduit/ezkl-swift-package.git + git push origin + git push origin --tags + env: + EZKL_PORTER_TOKEN: ${{ secrets.EZKL_PORTER_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index cbebdcf1e..f0117a065 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,7 @@ var/ node_modules /dist timingData.json -!tests/wasm/pk.key -!tests/wasm/vk.key +!tests/assets/pk.key +!tests/assets/vk.key docs/python/build -!tests/wasm/vk_aggr.key \ No newline at end of file +!tests/assets/vk_aggr.key \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index cdc2e8090..52057da39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1158,6 +1158,38 @@ dependencies = [ "serde", ] +[[package]] +name = "camino" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo-platform" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eee4243f1f26fc7a42710e7439c149e2b10b05472f88090acce52632f231a73a" +dependencies = [ + "camino", + "cargo-platform", + "semver 1.0.22", + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "cast" version = "0.3.0" @@ -1861,6 +1893,7 @@ version = "0.0.0" dependencies = [ "alloy", "bincode", + "camino", "chrono", "clap", "clap_complete", @@ -1916,7 +1949,10 @@ dependencies = [ "tokio-postgres", "tosubcommand", "tract-onnx", + "uniffi", + "uniffi_bindgen", "unzip-n", + "uuid", "wasm-bindgen", "wasm-bindgen-console-logger", "wasm-bindgen-rayon", @@ -2105,6 +2141,15 @@ dependencies = [ "yansi 1.0.1", ] +[[package]] +name = "fs-err" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41" +dependencies = [ + "autocfg", +] + [[package]] name = "fs4" version = "0.9.1" @@ -2268,6 +2313,17 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "goblin" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "group" version = "0.13.0" @@ -3755,6 +3811,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "plotters" version = "0.3.6" @@ -4690,6 +4752,26 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scroll" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f81c2fde025af7e69b1d1420531c8a8811ca898919db177141a85313b1cb932" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.53", +] + [[package]] name = "sec1" version = "0.7.3" @@ -4935,6 +5017,12 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + [[package]] name = "snark-verifier" version = "0.1.1" @@ -5256,6 +5344,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "textwrap" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" +dependencies = [ + "smawk", +] + [[package]] name = "thiserror" version = "1.0.58" @@ -5463,6 +5560,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + [[package]] name = "toml_datetime" version = "0.6.5" @@ -5779,6 +5885,134 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "uniffi" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31bff6daf87277a9014bcdefbc2842b0553392919d1096843c5aad899ca4588" +dependencies = [ + "anyhow", + "uniffi_bindgen", + "uniffi_build", + "uniffi_core", + "uniffi_macros", +] + +[[package]] +name = "uniffi_bindgen" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96061d7e01b185aa405f7c9b134741ab3e50cc6796a47d6fd8ab9a5364b5feed" +dependencies = [ + "anyhow", + "askama", + "camino", + "cargo_metadata", + "fs-err", + "glob", + "goblin", + "heck 0.5.0", + "once_cell", + "paste", + "serde", + "textwrap", + "toml", + "uniffi_meta", + "uniffi_testing", + "uniffi_udl", +] + +[[package]] +name = "uniffi_build" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d6b86f9b221046af0c533eafe09ece04e2f1ded04ccdc9bba0ec09aec1c52bd" +dependencies = [ + "anyhow", + "camino", + "uniffi_bindgen", +] + +[[package]] +name = "uniffi_checksum_derive" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a22dbe67c1c957ac6e7611bdf605a6218aa86b0eebeb8be58b70ae85ad7d73dc" +dependencies = [ + "quote", + "syn 2.0.53", +] + +[[package]] +name = "uniffi_core" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3210d57d6ab6065ab47a2898dacdb7c606fd6a4156196831fa3bf82e34ac58a6" +dependencies = [ + "anyhow", + "bytes", + "camino", + "log", + "once_cell", + "paste", + "static_assertions", +] + +[[package]] +name = "uniffi_macros" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b58691741080935437dc862122e68d7414432a11824ac1137868de46181a0bd2" +dependencies = [ + "bincode", + "camino", + "fs-err", + "once_cell", + "proc-macro2", + "quote", + "serde", + "syn 2.0.53", + "toml", + "uniffi_meta", +] + +[[package]] +name = "uniffi_meta" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7663eacdbd9fbf4a88907ddcfe2e6fa85838eb6dc2418a7d91eebb3786f8e20b" +dependencies = [ + "anyhow", + "bytes", + "siphasher", + "uniffi_checksum_derive", +] + +[[package]] +name = "uniffi_testing" +version = "0.28.0" +source = "git+https://github.com/ElusAegis/uniffi-rs?branch=feat/testing-feature-build-fix#4684b9e7da2d9c964c2b3a73883947aab7370a06" +dependencies = [ + "anyhow", + "camino", + "cargo_metadata", + "fs-err", + "once_cell", +] + +[[package]] +name = "uniffi_udl" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cef408229a3a407fafa4c36dc4f6ece78a6fb258ab28d2b64bddd49c8cb680f6" +dependencies = [ + "anyhow", + "textwrap", + "uniffi_meta", + "uniffi_testing", + "weedle2", +] + [[package]] name = "unindent" version = "0.2.3" @@ -5829,6 +6063,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" @@ -6066,6 +6309,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "weedle2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "998d2c24ec099a87daf9467808859f9d82b61f1d9c9701251aea037f514eae0e" +dependencies = [ + "nom", +] + [[package]] name = "whoami" version = "1.5.1" diff --git a/Cargo.toml b/Cargo.toml index 07ea09d1a..577d5470c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ cargo-features = ["profile-rustflags"] name = "ezkl" version = "0.0.0" edition = "2021" +default-run = "ezkl" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -11,7 +12,7 @@ edition = "2021" # Name to be imported within python # Example: import ezkl name = "ezkl" -crate-type = ["cdylib", "rlib"] +crate-type = ["cdylib", "rlib", "staticlib"] [dependencies] @@ -19,34 +20,35 @@ halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/option halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [ "derive_serde", ] } -halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" } -rand = { version = "0.8", default_features = false } -itertools = { version = "0.10.3", default_features = false } -clap = { version = "4.5.3", features = ["derive"] } -serde = { version = "1.0.126", features = ["derive"], optional = true } -serde_json = { version = "1.0.97", default_features = false, features = [ - "float_roundtrip", - "raw_value", -], optional = true } -clap_complete = "4.5.2" -log = { version = "0.4.17", default_features = false, optional = true } -thiserror = { version = "1.0.38", default_features = false } -hex = { version = "0.4.3", default_features = false } +halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/cache-lookup-commitments", features = ["circuit-params"] } +rand = { version = "0.8", default-features = false } +itertools = { version = "0.10.3", default-features = false } +clap = { version = "4.5.3", features = ["derive"], optional = true } +serde = { version = "1.0.126", features = ["derive"] } +clap_complete = { version = "4.5.2", optional = true } +log = { version = "0.4.17", default-features = false } +thiserror = { version = "1.0.38", default-features = false } +hex = { version = "0.4.3", default-features = false } halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" } snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [ "derive_serde", ] } -halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" } -maybe-rayon = { version = "0.1.1", default_features = false } -bincode = { version = "1.3.3", default_features = false } +halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves", optional = true } +maybe-rayon = { version = "0.1.1", default-features = false } +bincode = { version = "1.3.3", default-features = false } unzip-n = "0.1.2" num = "0.4.1" -portable-atomic = "1.6.0" -tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" } -semver = "1.0.22" +portable-atomic = { version = "1.6.0", optional = true } +tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand", optional = true } +semver = { version = "1.0.22", optional = true } -# evm related deps [target.'cfg(not(target_arch = "wasm32"))'.dependencies] +serde_json = { version = "1.0.97", features = [ + "float_roundtrip", + "raw_value", +] } + +# evm related deps alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5fbf57bac99edef9d8475190109a7ea9fb7e5e83", features = [ "provider-http", "signers", @@ -54,52 +56,48 @@ alloy = { git = "https://github.com/alloy-rs/alloy", version = "0.1.0", rev = "5 "rpc-types-eth", "signer-wallet", "node-bindings", -] } -foundry-compilers = { version = "0.4.1", features = ["svm-solc"] } -ethabi = "18" -indicatif = { version = "0.17.5", features = ["rayon"] } -gag = { version = "1.0.0", default_features = false } +], optional = true } +foundry-compilers = { version = "0.4.1", features = ["svm-solc"], optional = true } +ethabi = { version = "18", optional = true } +indicatif = { version = "0.17.5", features = ["rayon"], optional = true } +gag = { version = "1.0.0", default-features = false, optional = true } instant = { version = "0.1" } -reqwest = { version = "0.12.4", default-features = false, features = [ - "default-tls", - "multipart", - "stream", -] } -openssl = { version = "0.10.55", features = ["vendored"] } -tokio-postgres = "0.7.10" -pg_bigdecimal = "0.1.5" -lazy_static = "1.4.0" -colored_json = { version = "3.0.1", default_features = false, optional = true } -regex = { version = "1", default_features = false } -tokio = { version = "1.35.0", default_features = false, features = [ - "macros", - "rt-multi-thread", -] } -pyo3 = { version = "0.21.2", features = [ - "extension-module", - "abi3-py37", - "macros", -], default_features = false, optional = true } -pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch = "migration-pyo3-0.21", features = [ - "attributes", - "tokio-runtime", -], default_features = false, optional = true } - -pyo3-log = { version = "0.10.0", default_features = false, optional = true } -tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true } +reqwest = { version = "0.12.4", default-features = false, features = ["default-tls", "multipart", "stream"], optional = true } +openssl = { version = "0.10.55", features = ["vendored"], optional = true } +tokio-postgres = { version = "0.7.10", optional = true } +pg_bigdecimal = { version = "0.1.5", optional = true } +lazy_static = { version = "1.4.0", optional = true } +colored_json = { version = "3.0.1", default-features = false, optional = true } +regex = { version = "1", default-features = false, optional = true } +tokio = { version = "1.35.0", default-features = false, features = ["macros", "rt-multi-thread"], optional = true } +pyo3 = { version = "0.21.2", features = ["extension-module", "abi3-py37", "macros"], default-features = false, optional = true } +pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = ["attributes", "tokio-runtime"], default-features = false, optional = true } +pyo3-log = { version = "0.10.0", default-features = false, optional = true } +tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default-features = false, optional = true } tabled = { version = "0.12.0", optional = true } metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true } objc = { version = "0.2.4", optional = true } -mimalloc = "0.1" +mimalloc = { version = "0.1", optional = true } + +# universal bindings +uniffi = { version = "=0.28.0", optional = true } +getrandom = { version = "0.2.8", optional = true } +uniffi_bindgen = { version = "=0.28.0", optional = true } +camino = { version = "^1.1", optional = true } +uuid = { version = "1.10.0", features = ["v4"], optional = true } [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies] -colored = { version = "2.0.0", default_features = false, optional = true } -env_logger = { version = "0.10.0", default_features = false, optional = true } -chrono = "0.4.31" -sha256 = "1.4.0" +colored = { version = "2.0.0", default-features = false, optional = true } +env_logger = { version = "0.10.0", default-features = false, optional = true } +chrono = { version = "0.4.31", optional = true } +sha256 = { version = "1.4.0", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] +serde_json = { version = "1.0.97", default-features = false, features = [ + "float_roundtrip", + "raw_value", +] } getrandom = { version = "0.2.8", features = ["js"] } instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] } @@ -115,6 +113,10 @@ wasm-bindgen-console-logger = "0.1.1" [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } + +[build-dependencies] +uniffi = { version = "0.28", features = ["build"], optional = true } + [dev-dependencies] tempfile = "3.3.0" lazy_static = "1.4.0" @@ -188,29 +190,47 @@ test = false bench = false required-features = ["ezkl"] +[[bin]] +name = "ios_gen_bindings" +required-features = ["ios-bindings", "uuid", "camino", "uniffi_bindgen"] + [features] web = ["wasm-bindgen-rayon"] -default = [ - "ezkl", - "mv-lookup", - "precompute-coset", - "no-banner", - "parallel-poly-read", -] +default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"] onnx = ["dep:tract-onnx"] python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"] +ios-bindings = ["mv-lookup", "precompute-coset", "parallel-poly-read", "uniffi"] +ios-bindings-test = ["ios-bindings", "uniffi/bindgen-tests"] ezkl = [ "onnx", - "serde", - "serde_json", - "log", - "colored", - "env_logger", + "dep:colored", + "dep:env_logger", "tabled/color", + "serde_json/std", "colored_json", - "halo2_proofs/circuit-params", + "dep:alloy", + "dep:foundry-compilers", + "dep:ethabi", + "dep:indicatif", + "dep:gag", + "dep:reqwest", + "dep:openssl", + "dep:tokio-postgres", + "dep:pg_bigdecimal", + "dep:lazy_static", + "dep:regex", + "dep:tokio", + "dep:mimalloc", + "dep:chrono", + "dep:sha256", + "dep:portable-atomic", + "dep:clap_complete", + "dep:halo2_solidity_verifier", + "dep:semver", + "dep:clap", + "dep:tosubcommand", ] -parallel-poly-read = ["halo2_proofs/parallel-poly-read"] +parallel-poly-read = ["halo2_proofs/circuit-params", "halo2_proofs/parallel-poly-read"] mv-lookup = [ "halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", @@ -224,7 +244,6 @@ empty-cmd = [] no-banner = [] no-update = [] - # icicle patch to 0.1.0 if feature icicle is enabled [patch.'https://github.com/ingonyama-zk/icicle'] icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" } @@ -232,8 +251,12 @@ icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = [patch.'https://github.com/zkonduit/halo2'] halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" } +[patch.crates-io] +uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" } + [profile.release] rustflags = ["-C", "relocation-model=pic"] lto = "fat" codegen-units = 1 # panic = "abort" + diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..31bdb681c --- /dev/null +++ b/build.rs @@ -0,0 +1,7 @@ +fn main() { + if cfg!(feature = "ios-bindings-test") { + println!("cargo::rustc-env=UNIFFI_CARGO_BUILD_EXTRA_ARGS=--features=ios-bindings --no-default-features"); + } + + println!("cargo::rerun-if-changed=build.rs"); +} diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index 702770440..dd56a3c98 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -285,7 +285,7 @@ where } pub fn runconv() { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] env_logger::init(); const KERNEL_HEIGHT: usize = 5; diff --git a/examples/mlp_4d_einsum.rs b/examples/mlp_4d_einsum.rs index 81f310b0f..5dadfe16b 100644 --- a/examples/mlp_4d_einsum.rs +++ b/examples/mlp_4d_einsum.rs @@ -220,7 +220,7 @@ impl = Tensor::::new( diff --git a/src/bin/ezkl.rs b/src/bin/ezkl.rs index 44c64cd32..41d0a0fe5 100644 --- a/src/bin/ezkl.rs +++ b/src/bin/ezkl.rs @@ -1,28 +1,28 @@ // ignore file if compiling for wasm #[global_allocator] -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use clap::{CommandFactory, Parser}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use colored_json::ToColoredJson; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use ezkl::commands::Cli; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use ezkl::execute::run; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use ezkl::logger::init_logger; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use log::{error, info}; #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))] use rand::prelude::SliceRandom; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] #[cfg(feature = "icicle")] use std::env; #[tokio::main(flavor = "current_thread")] -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub async fn main() { let args = Cli::parse(); @@ -59,7 +59,7 @@ pub async fn main() { } } -#[cfg(target_arch = "wasm32")] +#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] pub fn main() {} #[cfg(not(any(target_arch = "wasm32", feature = "no-banner")))] diff --git a/src/bin/ios_gen_bindings.rs b/src/bin/ios_gen_bindings.rs new file mode 100644 index 000000000..e1bdd6b86 --- /dev/null +++ b/src/bin/ios_gen_bindings.rs @@ -0,0 +1,269 @@ +use camino::Utf8Path; +use std::fs; +use std::fs::remove_dir_all; +use std::path::{Path, PathBuf}; +use std::process::Command; +use uniffi_bindgen::bindings::SwiftBindingGenerator; +use uniffi_bindgen::library_mode::generate_bindings; +use uuid::Uuid; + +fn main() { + let library_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME is not set"); + let mode = determine_build_mode(); + build_bindings(&library_name, mode); +} + +/// Determines the build mode based on the CONFIGURATION environment variable. +/// Defaults to "release" if not set or unrecognized. +/// "release" mode takes longer to build but produces optimized code, which has smaller size and is faster. +fn determine_build_mode() -> &'static str { + match std::env::var("CONFIGURATION").map(|s| s.to_lowercase()) { + Ok(ref config) if config == "debug" => "debug", + _ => "release", + } +} + +/// Builds the Swift bindings and XCFramework for the specified library and build mode. +fn build_bindings(library_name: &str, mode: &str) { + // Get the root directory of this Cargo project + let manifest_dir = std::env::var_os("CARGO_MANIFEST_DIR") + .map(PathBuf::from) + .unwrap_or_else(|| std::env::current_dir().unwrap()); + + // Define the build directory inside the manifest directory + let build_dir = manifest_dir.join("build"); + + // Create a temporary directory to store the bindings and combined library + let tmp_dir = mktemp_local(&build_dir); + + // Define directories for Swift bindings and output bindings + let swift_bindings_dir = tmp_dir.join("SwiftBindings"); + let bindings_out = create_bindings_out_dir(&tmp_dir); + let framework_out = bindings_out.join("EzklCore.xcframework"); + + // Define target architectures for building + // We currently only support iOS devices and simulators running on ARM Macs + // This is due to limiting the library size to under 100MB for GitHub Commit Size Limit + // To support older Macs (Intel), follow the instructions in the comments below + #[allow(clippy::useless_vec)] + let target_archs = vec![ + vec!["aarch64-apple-ios"], // iOS device + vec!["aarch64-apple-ios-sim"], // iOS simulator ARM Mac + // vec!["aarch64-apple-ios-sim", "x86_64-apple-ios"], // TODO - replace the above line with this line to allow running on older Macs (Intel) + ]; + + // Build the library for each architecture and combine them + let out_lib_paths: Vec = target_archs + .iter() + .map(|archs| build_combined_archs(library_name, archs, &build_dir, mode)) + .collect(); + + // Generate the path to the built dynamic library (.dylib) + let out_dylib_path = build_dir.join(format!( + "{}/{}/lib{}.dylib", + target_archs[0][0], mode, library_name + )); + + // Generate Swift bindings using uniffi_bindgen + generate_ios_bindings(&out_dylib_path, &swift_bindings_dir) + .expect("Failed to generate iOS bindings"); + + // Move the generated Swift file to the bindings output directory + fs::rename( + swift_bindings_dir.join(format!("{}.swift", library_name)), + bindings_out.join("EzklCore.swift"), + ) + .expect("Failed to copy swift bindings file"); + + // Rename the `ios_ezklFFI.modulemap` file to `module.modulemap` + fs::rename( + swift_bindings_dir.join(format!("{}FFI.modulemap", library_name)), + swift_bindings_dir.join("module.modulemap"), + ) + .expect("Failed to rename modulemap file"); + + // Create the XCFramework from the combined libraries and Swift bindings + create_xcframework(&out_lib_paths, &swift_bindings_dir, &framework_out); + + // Define the destination directory for the bindings + let bindings_dest = build_dir.join("EzklCoreBindings"); + if bindings_dest.exists() { + fs::remove_dir_all(&bindings_dest).expect("Failed to remove existing bindings directory"); + } + + // Move the bindings output to the destination directory + fs::rename(&bindings_out, &bindings_dest).expect("Failed to move framework into place"); + + // Clean up temporary directories + cleanup_temp_dirs(&build_dir); +} + +/// Creates the output directory for the bindings. +/// Returns the path to the bindings output directory. +fn create_bindings_out_dir(base_dir: &Path) -> PathBuf { + let bindings_out = base_dir.join("EzklCoreBindings"); + fs::create_dir_all(&bindings_out).expect("Failed to create bindings output directory"); + bindings_out +} + +/// Builds the library for each architecture and combines them into a single library using lipo. +/// Returns the path to the combined library. +fn build_combined_archs( + library_name: &str, + archs: &[&str], + build_dir: &Path, + mode: &str, +) -> PathBuf { + // Build the library for each architecture + let out_lib_paths: Vec = archs + .iter() + .map(|&arch| { + build_for_arch(arch, build_dir, mode); + build_dir + .join(arch) + .join(mode) + .join(format!("lib{}.a", library_name)) + }) + .collect(); + + // Create a unique temporary directory for the combined library + let lib_out = mktemp_local(build_dir).join(format!("lib{}.a", library_name)); + + // Combine the libraries using lipo + let mut lipo_cmd = Command::new("lipo"); + lipo_cmd + .arg("-create") + .arg("-output") + .arg(lib_out.to_str().unwrap()); + for lib_path in &out_lib_paths { + lipo_cmd.arg(lib_path.to_str().unwrap()); + } + + let status = lipo_cmd.status().expect("Failed to run lipo command"); + if !status.success() { + panic!("lipo command failed with status: {}", status); + } + + lib_out +} + +/// Builds the library for a specific architecture. +fn build_for_arch(arch: &str, build_dir: &Path, mode: &str) { + // Ensure the target architecture is installed + install_arch(arch); + + // Run cargo build for the specified architecture and mode + let mut build_cmd = Command::new("cargo"); + build_cmd + .arg("build") + .arg("--no-default-features") + .arg("--features") + .arg("ios-bindings"); + + if mode == "release" { + build_cmd.arg("--release"); + } + build_cmd + .arg("--lib") + .env("CARGO_BUILD_TARGET_DIR", build_dir) + .env("CARGO_BUILD_TARGET", arch); + + let status = build_cmd.status().expect("Failed to run cargo build"); + if !status.success() { + panic!("cargo build failed for architecture: {}", arch); + } +} + +/// Installs the specified target architecture using rustup. +fn install_arch(arch: &str) { + let status = Command::new("rustup") + .arg("target") + .arg("add") + .arg(arch) + .status() + .expect("Failed to run rustup command"); + + if !status.success() { + panic!("Failed to install target architecture: {}", arch); + } +} + +/// Generates Swift bindings for the iOS library using uniffi_bindgen. +fn generate_ios_bindings(dylib_path: &Path, binding_dir: &Path) -> Result<(), std::io::Error> { + // Remove existing binding directory if it exists + if binding_dir.exists() { + remove_dir_all(binding_dir)?; + } + + // Generate the Swift bindings using uniffi_bindgen + generate_bindings( + Utf8Path::from_path(dylib_path).ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid dylib path") + })?, + None, + &SwiftBindingGenerator, + None, + Utf8Path::from_path(binding_dir).ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid Swift bindings directory", + ) + })?, + true, + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; + + Ok(()) +} + +/// Creates an XCFramework from the combined libraries and Swift bindings. +fn create_xcframework(lib_paths: &[PathBuf], swift_bindings_dir: &Path, framework_out: &Path) { + let mut xcbuild_cmd = Command::new("xcodebuild"); + xcbuild_cmd.arg("-create-xcframework"); + + // Add each library and its corresponding headers to the xcodebuild command + for lib_path in lib_paths { + println!("Including library: {:?}", lib_path); + xcbuild_cmd.arg("-library"); + xcbuild_cmd.arg(lib_path.to_str().unwrap()); + xcbuild_cmd.arg("-headers"); + xcbuild_cmd.arg(swift_bindings_dir.to_str().unwrap()); + } + + xcbuild_cmd.arg("-output"); + xcbuild_cmd.arg(framework_out.to_str().unwrap()); + + let status = xcbuild_cmd.status().expect("Failed to run xcodebuild"); + if !status.success() { + panic!("xcodebuild failed with status: {}", status); + } +} + +/// Creates a temporary directory inside the build path with a unique UUID. +/// This ensures unique build artifacts for concurrent builds. +fn mktemp_local(build_path: &Path) -> PathBuf { + let dir = tmp_local(build_path).join(Uuid::new_v4().to_string()); + fs::create_dir(&dir).expect("Failed to create temporary directory"); + dir +} + +/// Gets the path to the local temporary directory inside the build path. +fn tmp_local(build_path: &Path) -> PathBuf { + let tmp_path = build_path.join("tmp"); + if let Ok(metadata) = fs::metadata(&tmp_path) { + if !metadata.is_dir() { + panic!("Expected 'tmp' to be a directory"); + } + } else { + fs::create_dir_all(&tmp_path).expect("Failed to create local temporary directory"); + } + tmp_path +} + +/// Cleans up temporary directories inside the build path. +fn cleanup_temp_dirs(build_dir: &Path) { + let tmp_dir = build_dir.join("tmp"); + if tmp_dir.exists() { + fs::remove_dir_all(tmp_dir).expect("Failed to remove temporary directories"); + } +} diff --git a/src/bindings/mod.rs b/src/bindings/mod.rs new file mode 100644 index 000000000..df4dbb81f --- /dev/null +++ b/src/bindings/mod.rs @@ -0,0 +1,12 @@ +/// Python bindings +#[cfg(feature = "python-bindings")] +pub mod python; +/// Universal bindings for all platforms +#[cfg(any( + feature = "ios-bindings", + all(target_arch = "wasm32", target_os = "unknown") +))] +pub mod universal; +/// wasm prover and verifier +#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] +pub mod wasm; diff --git a/src/python.rs b/src/bindings/python.rs similarity index 100% rename from src/python.rs rename to src/bindings/python.rs diff --git a/src/bindings/universal.rs b/src/bindings/universal.rs new file mode 100644 index 000000000..a68fdf0ff --- /dev/null +++ b/src/bindings/universal.rs @@ -0,0 +1,579 @@ +use halo2_proofs::{ + plonk::*, + poly::{ + commitment::{CommitmentScheme, ParamsProver}, + ipa::{ + commitment::{IPACommitmentScheme, ParamsIPA}, + multiopen::{ProverIPA, VerifierIPA}, + strategy::SingleStrategy as IPASingleStrategy, + }, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy as KZGSingleStrategy, + }, + VerificationStrategy, + }, +}; +use std::fmt::Display; +use std::io::BufReader; +use std::str::FromStr; + +use crate::{ + circuit::region::RegionSettings, + graph::GraphSettings, + pfsys::{ + create_proof_circuit, + evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript}, + verify_proof_circuit, TranscriptType, + }, + tensor::TensorType, + CheckMode, Commitments, EZKLError as InnerEZKLError, +}; + +use crate::graph::{GraphCircuit, GraphWitness}; +use halo2_solidity_verifier::encode_calldata; +use halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + ff::{FromUniformBytes, PrimeField}, +}; +use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript}; + +/// Wrapper around the Error Message +#[cfg_attr(feature = "ios-bindings", derive(uniffi::Error))] +#[derive(Debug)] +pub enum EZKLError { + /// Some Comment + InternalError(String), +} + +impl Display for EZKLError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EZKLError::InternalError(e) => write!(f, "Internal error: {}", e), + } + } +} + +impl From for EZKLError { + fn from(e: InnerEZKLError) -> Self { + EZKLError::InternalError(e.to_string()) + } +} + +/// Encode verifier calldata from proof and ethereum vk_address +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn encode_verifier_calldata( + // TODO - shuold it be pub(crate) or pub or pub(super)? + proof: Vec, + vk_address: Option>, +) -> Result, EZKLError> { + let snark: crate::pfsys::Snark = + serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?; + + let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address { + let array: [u8; 20] = + serde_json::from_slice(&vk_address[..]).map_err(InnerEZKLError::from)?; + Some(array) + } else { + None + }; + + let flattened_instances = snark.instances.into_iter().flatten(); + + let encoded = encode_calldata( + vk_address, + &snark.proof, + &flattened_instances.collect::>(), + ); + + Ok(encoded) +} + +/// Generate witness from compiled circuit and input json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn gen_witness(compiled_circuit: Vec, input: Vec) -> Result, EZKLError> { + let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| { + EZKLError::InternalError(format!("Failed to deserialize compiled model: {}", e)) + })?; + let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize input: {}", e)))?; + + let mut input = circuit + .load_graph_input(&input) + .map_err(|e| EZKLError::InternalError(format!("{}", e)))?; + + let witness = circuit + .forward::>( + &mut input, + None, + None, + RegionSettings::all_true( + circuit.settings().run_args.decomp_base, + circuit.settings().run_args.decomp_legs, + ), + ) + .map_err(|e| EZKLError::InternalError(format!("{}", e)))?; + + serde_json::to_vec(&witness) + .map_err(|e| EZKLError::InternalError(format!("Failed to serialize witness: {}", e))) +} + +/// Generate verifying key from compiled circuit, and parameters srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn gen_vk( + compiled_circuit: Vec, + srs: Vec, + compress_selectors: bool, +) -> Result, EZKLError> { + let mut reader = BufReader::new(&srs[..]); + let params: ParamsKZG = get_params(&mut reader)?; + + let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?; + + let vk = create_vk_lean::, Fr, GraphCircuit>( + &circuit, + ¶ms, + compress_selectors, + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to create verifying key: {}", e)))?; + + let mut serialized_vk = Vec::new(); + vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes) + .map_err(|e| { + EZKLError::InternalError(format!("Failed to serialize verifying key: {}", e)) + })?; + + Ok(serialized_vk) +} + +/// Generate proving key from vk, compiled circuit and parameters srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn gen_pk( + vk: Vec, + compiled_circuit: Vec, + srs: Vec, +) -> Result, EZKLError> { + let mut reader = BufReader::new(&srs[..]); + let params: ParamsKZG = get_params(&mut reader)?; + + let circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?; + + let mut reader = BufReader::new(&vk[..]); + let vk = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit.settings().clone(), + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?; + + let pk = create_pk_lean::, Fr, GraphCircuit>(vk, &circuit, ¶ms) + .map_err(|e| EZKLError::InternalError(format!("Failed to create proving key: {}", e)))?; + + let mut serialized_pk = Vec::new(); + pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes) + .map_err(|e| EZKLError::InternalError(format!("Failed to serialize proving key: {}", e)))?; + + Ok(serialized_pk) +} + +/// Verify proof with vk, proof json, circuit settings json and srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn verify( + proof: Vec, + vk: Vec, + settings: Vec, + srs: Vec, +) -> Result { + let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize settings: {}", e)))?; + + let proof: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?; + + let mut reader = BufReader::new(&vk[..]); + let vk = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings.clone(), + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?; + + let orig_n = 1 << circuit_settings.run_args.logrows; + let commitment = circuit_settings.run_args.commitment.into(); + + let mut reader = BufReader::new(&srs[..]); + let result = match commitment { + Commitments::KZG => { + let params: ParamsKZG = get_params(&mut reader)?; + let strategy = KZGSingleStrategy::new(params.verifier_params()); + match proof.transcript_type { + TranscriptType::EVM => verify_proof_circuit::< + VerifierSHPLONK<'_, Bn256>, + KZGCommitmentScheme, + KZGSingleStrategy<_>, + _, + EvmTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n), + TranscriptType::Poseidon => { + verify_proof_circuit::< + VerifierSHPLONK<'_, Bn256>, + KZGCommitmentScheme, + KZGSingleStrategy<_>, + _, + PoseidonTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n) + } + } + } + Commitments::IPA => { + let params: ParamsIPA<_> = get_params(&mut reader)?; + let strategy = IPASingleStrategy::new(params.verifier_params()); + match proof.transcript_type { + TranscriptType::EVM => verify_proof_circuit::< + VerifierIPA<_>, + IPACommitmentScheme, + IPASingleStrategy<_>, + _, + EvmTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n), + TranscriptType::Poseidon => { + verify_proof_circuit::< + VerifierIPA<_>, + IPACommitmentScheme, + IPASingleStrategy<_>, + _, + PoseidonTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n) + } + } + } + }; + + match result { + Ok(_) => Ok(true), + Err(e) => Err(EZKLError::InternalError(format!( + "Verification failed: {}", + e + ))), + } +} + +/// Verify aggregate proof with vk, proof, circuit settings and srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn verify_aggr( + proof: Vec, + vk: Vec, + logrows: u64, + srs: Vec, + commitment: &str, +) -> Result { + let proof: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proof: {}", e)))?; + + let mut reader = BufReader::new(&vk[..]); + let vk = VerifyingKey::::read::<_, AggregationCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + (), + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize vk: {}", e)))?; + + let commit = Commitments::from_str(commitment) + .map_err(|e| EZKLError::InternalError(format!("Invalid commitment: {}", e)))?; + + let orig_n = 1 << logrows; + + let mut reader = BufReader::new(&srs[..]); + let result = match commit { + Commitments::KZG => { + let params: ParamsKZG = get_params(&mut reader)?; + let strategy = KZGSingleStrategy::new(params.verifier_params()); + match proof.transcript_type { + TranscriptType::EVM => verify_proof_circuit::< + VerifierSHPLONK<'_, Bn256>, + KZGCommitmentScheme, + KZGSingleStrategy<_>, + _, + EvmTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n), + + TranscriptType::Poseidon => { + verify_proof_circuit::< + VerifierSHPLONK<'_, Bn256>, + KZGCommitmentScheme, + KZGSingleStrategy<_>, + _, + PoseidonTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n) + } + } + } + Commitments::IPA => { + let params: ParamsIPA<_> = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err( + |e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e)), + )?; + let strategy = IPASingleStrategy::new(params.verifier_params()); + match proof.transcript_type { + TranscriptType::EVM => verify_proof_circuit::< + VerifierIPA<_>, + IPACommitmentScheme, + IPASingleStrategy<_>, + _, + EvmTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n), + TranscriptType::Poseidon => { + verify_proof_circuit::< + VerifierIPA<_>, + IPACommitmentScheme, + IPASingleStrategy<_>, + _, + PoseidonTranscript, + >(&proof, ¶ms, &vk, strategy, orig_n) + } + } + } + }; + + result + .map(|_| true) + .map_err(|e| EZKLError::InternalError(format!("{}", e))) +} + +/// Prove in browser with compiled circuit, witness json, proving key, and srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn prove( + witness: Vec, + pk: Vec, + compiled_circuit: Vec, + srs: Vec, +) -> Result, EZKLError> { + #[cfg(feature = "det-prove")] + log::set_max_level(log::LevelFilter::Debug); + #[cfg(not(feature = "det-prove"))] + log::set_max_level(log::LevelFilter::Info); + + let mut circuit: GraphCircuit = bincode::deserialize(&compiled_circuit[..]) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize circuit: {}", e)))?; + + let data: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?; + + let mut reader = BufReader::new(&pk[..]); + let pk = ProvingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit.settings().clone(), + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?; + + circuit + .load_graph_witness(&data) + .map_err(InnerEZKLError::from)?; + let public_inputs = circuit + .prepare_public_inputs(&data) + .map_err(InnerEZKLError::from)?; + let proof_split_commits: Option = data.into(); + + let mut reader = BufReader::new(&srs[..]); + let commitment = circuit.settings().run_args.commitment.into(); + + let proof = match commitment { + Commitments::KZG => { + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err( + |e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)), + )?; + + create_proof_circuit::< + KZGCommitmentScheme, + _, + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + KZGSingleStrategy<_>, + _, + EvmTranscript<_, _, _, _>, + EvmTranscript<_, _, _, _>, + >( + circuit, + vec![public_inputs], + ¶ms, + &pk, + CheckMode::UNSAFE, + Commitments::KZG, + TranscriptType::EVM, + proof_split_commits, + None, + ) + } + Commitments::IPA => { + let params: ParamsIPA<_> = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err( + |e| EZKLError::InternalError(format!("Failed to deserialize srs: {}", e)), + )?; + + create_proof_circuit::< + IPACommitmentScheme, + _, + ProverIPA<_>, + VerifierIPA<_>, + IPASingleStrategy<_>, + _, + EvmTranscript<_, _, _, _>, + EvmTranscript<_, _, _, _>, + >( + circuit, + vec![public_inputs], + ¶ms, + &pk, + CheckMode::UNSAFE, + Commitments::IPA, + TranscriptType::EVM, + proof_split_commits, + None, + ) + } + } + .map_err(InnerEZKLError::from)?; + + Ok(serde_json::to_vec(&proof).map_err(InnerEZKLError::from)?) +} + +/// Validate the witness json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn witness_validation(witness: Vec) -> Result { + let _: GraphWitness = serde_json::from_slice(&witness[..]).map_err(InnerEZKLError::from)?; + + Ok(true) +} + +/// Validate the compiled circuit +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn compiled_circuit_validation(compiled_circuit: Vec) -> Result { + let _: GraphCircuit = bincode::deserialize(&compiled_circuit[..]).map_err(|e| { + EZKLError::InternalError(format!("Failed to deserialize compiled circuit: {}", e)) + })?; + + Ok(true) +} + +/// Validate the input json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn input_validation(input: Vec) -> Result { + let _: crate::graph::input::GraphData = + serde_json::from_slice(&input[..]).map_err(InnerEZKLError::from)?; + + Ok(true) +} + +/// Validate the proof json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn proof_validation(proof: Vec) -> Result { + let _: crate::pfsys::Snark = + serde_json::from_slice(&proof[..]).map_err(InnerEZKLError::from)?; + + Ok(true) +} + +/// Validate the verifying key given the settings json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn vk_validation(vk: Vec, settings: Vec) -> Result { + let circuit_settings: GraphSettings = + serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?; + + let mut reader = BufReader::new(&vk[..]); + let _ = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize verifying key: {}", e)))?; + + Ok(true) +} + +/// Validate the proving key given the settings json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn pk_validation(pk: Vec, settings: Vec) -> Result { + let circuit_settings: GraphSettings = + serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?; + + let mut reader = BufReader::new(&pk[..]); + let _ = ProvingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize proving key: {}", e)))?; + + Ok(true) +} + +/// Validate the settings json +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn settings_validation(settings: Vec) -> Result { + let _: GraphSettings = serde_json::from_slice(&settings[..]).map_err(InnerEZKLError::from)?; + + Ok(true) +} + +/// Validate the srs +#[cfg_attr(feature = "ios-bindings", uniffi::export)] +pub(crate) fn srs_validation(srs: Vec) -> Result { + let mut reader = BufReader::new(&srs[..]); + let _: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader).map_err(|e| { + EZKLError::InternalError(format!("Failed to deserialize params: {}", e)) + })?; + + Ok(true) +} + +// HELPER FUNCTIONS + +fn get_params< + Scheme: for<'a> halo2_proofs::poly::commitment::Params<'a, halo2curves::bn256::G1Affine>, +>( + mut reader: &mut BufReader<&[u8]>, +) -> Result { + halo2_proofs::poly::commitment::Params::::read(&mut reader) + .map_err(|e| EZKLError::InternalError(format!("Failed to deserialize params: {}", e))) +} + +/// Creates a [ProvingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target +pub fn create_vk_lean>( + circuit: &C, + params: &'_ Scheme::ParamsProver, + compress_selectors: bool, +) -> Result, halo2_proofs::plonk::Error> +where + C: Circuit, + ::Scalar: FromUniformBytes<64>, +{ + // Real proof + let empty_circuit = >::without_witnesses(circuit); + + // Initialize the verifying key + let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?; + Ok(vk) +} +/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target +pub fn create_pk_lean>( + vk: VerifyingKey, + circuit: &C, + params: &'_ Scheme::ParamsProver, +) -> Result, halo2_proofs::plonk::Error> +where + C: Circuit, + ::Scalar: FromUniformBytes<64>, +{ + // Real proof + let empty_circuit = >::without_witnesses(circuit); + + // Initialize the proving key + let pk = keygen_pk(params, vk, &empty_circuit)?; + Ok(pk) +} diff --git a/src/bindings/wasm.rs b/src/bindings/wasm.rs new file mode 100644 index 000000000..3adc41843 --- /dev/null +++ b/src/bindings/wasm.rs @@ -0,0 +1,372 @@ +use crate::{ + circuit::modules::{ + polycommit::PolyCommitChip, + poseidon::{ + spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}, + PoseidonChip, + }, + Module, + }, + fieldutils::{felt_to_integer_rep, integer_rep_to_felt}, + graph::{ + modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit, + GraphSettings, + }, +}; +use console_error_panic_hook; +use halo2_proofs::{ + plonk::*, + poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG}, +}; +use halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + ff::PrimeField, +}; +use wasm_bindgen::prelude::*; +use wasm_bindgen_console_logger::DEFAULT_LOGGER; + +use crate::bindings::universal::{ + compiled_circuit_validation, encode_verifier_calldata, gen_pk, gen_vk, gen_witness, + input_validation, pk_validation, proof_validation, settings_validation, srs_validation, + verify_aggr, vk_validation, witness_validation, EZKLError as ExternalEZKLError, +}; +#[cfg(feature = "web")] +pub use wasm_bindgen_rayon::init_thread_pool; + +impl From for JsError { + fn from(e: ExternalEZKLError) -> Self { + JsError::new(&format!("{}", e)) + } +} + +#[wasm_bindgen] +/// Initialize logger for wasm +pub fn init_logger() { + log::set_logger(&DEFAULT_LOGGER).unwrap(); +} + +#[wasm_bindgen] +/// Initialize panic hook for wasm +pub fn init_panic_hook() { + console_error_panic_hook::set_once(); +} + +/// Wrapper around the halo2 encode call data method +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn encodeVerifierCalldata( + proof: wasm_bindgen::Clamped>, + vk_address: Option>, +) -> Result, JsError> { + encode_verifier_calldata(proof.0, vk_address).map_err(JsError::from) +} + +/// Converts a hex string to a byte array +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn feltToBigEndian(array: wasm_bindgen::Clamped>) -> Result { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + Ok(format!("{:?}", felt)) +} + +/// Converts a felt to a little endian string +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn feltToLittleEndian(array: wasm_bindgen::Clamped>) -> Result { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + let repr = serde_json::to_string(&felt).unwrap(); + let b: String = serde_json::from_str(&repr).unwrap(); + Ok(b) +} + +/// Converts a hex string to a byte array +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn feltToInt( + array: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + Ok(wasm_bindgen::Clamped( + serde_json::to_vec(&felt_to_integer_rep(felt)) + .map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?, + )) +} + +/// Converts felts to a floating point element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn feltToFloat( + array: wasm_bindgen::Clamped>, + scale: crate::Scale, +) -> Result { + let felt: Fr = serde_json::from_slice(&array[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; + let int_rep = felt_to_integer_rep(felt); + let multiplier = scale_to_multiplier(scale); + Ok(int_rep as f64 / multiplier) +} + +/// Converts a floating point number to a hex string representing a fixed point field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn floatToFelt( + input: f64, + scale: crate::Scale, +) -> Result>, JsError> { + let int_rep = + quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?; + let felt = integer_rep_to_felt(int_rep); + let vec = crate::pfsys::field_to_string::(&felt); + Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err( + |e| JsError::new(&format!("Failed to serialize a float to felt{}", e)), + )?)) +} + +/// Generate a kzg commitment. +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn kzgCommit( + message: wasm_bindgen::Clamped>, + vk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, + params_ser: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + let message: Vec = serde_json::from_slice(&message[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; + + let mut reader = std::io::BufReader::new(¶ms_ser[..]); + let params: ParamsKZG = + halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) + .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; + + let mut reader = std::io::BufReader::new(&vk[..]); + let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; + let vk = VerifyingKey::::read::<_, GraphCircuit>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + circuit_settings, + ) + .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; + + let output = PolyCommitChip::commit::>( + message, + (vk.cs().blinding_factors() + 1) as u32, + ¶ms, + ); + + Ok(wasm_bindgen::Clamped( + serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?, + )) +} + +/// Converts a buffer to vector of 4 u64s representing a fixed point field element +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn bufferToVecOfFelt( + buffer: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + // Convert the buffer to a slice + let buffer: &[u8] = &buffer; + + // Divide the buffer into chunks of 64 bytes + let chunks = buffer.chunks_exact(16); + + // Get the remainder + let remainder = chunks.remainder(); + + // Add 0s to the remainder to make it 64 bytes + let mut remainder = remainder.to_vec(); + + // Collect chunks into a Vec<[u8; 16]>. + let chunks: Result, JsError> = chunks + .map(|slice| { + let array: [u8; 16] = slice + .try_into() + .map_err(|_| JsError::new("failed to slice input chunks"))?; + Ok(array) + }) + .collect(); + + let mut chunks = chunks?; + + if remainder.len() != 0 { + remainder.resize(16, 0); + // Convert the Vec to [u8; 16] + let remainder_array: [u8; 16] = remainder + .try_into() + .map_err(|_| JsError::new("failed to slice remainder"))?; + // append the remainder to the chunks + chunks.push(remainder_array); + } + + // Convert each chunk to a field element + let field_elements: Vec = chunks + .iter() + .map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x))) + .collect(); + + Ok(wasm_bindgen::Clamped( + serde_json::to_vec(&field_elements) + .map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?, + )) +} + +/// Generate a poseidon hash in browser. Input message +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn poseidonHash( + message: wasm_bindgen::Clamped>, +) -> Result>, JsError> { + let message: Vec = serde_json::from_slice(&message[..]) + .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; + + let output = + PoseidonChip::::run( + message.clone(), + ) + .map_err(|e| JsError::new(&format!("{}", e)))?; + + Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err( + |e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)), + )?)) +} + +/// Generate a witness file from input.json, compiled model and a settings.json file. +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn genWitness( + compiled_circuit: wasm_bindgen::Clamped>, + input: wasm_bindgen::Clamped>, +) -> Result, JsError> { + gen_witness(compiled_circuit.0, input.0).map_err(JsError::from) +} + +/// Generate verifying key in browser +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn genVk( + compiled_circuit: wasm_bindgen::Clamped>, + params_ser: wasm_bindgen::Clamped>, + compress_selectors: bool, +) -> Result, JsError> { + gen_vk(compiled_circuit.0, params_ser.0, compress_selectors).map_err(JsError::from) +} + +/// Generate proving key in browser +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn genPk( + vk: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, + params_ser: wasm_bindgen::Clamped>, +) -> Result, JsError> { + gen_pk(vk.0, compiled_circuit.0, params_ser.0).map_err(JsError::from) +} + +/// Verify proof in browser using wasm +#[wasm_bindgen] +pub fn verify( + proof_js: wasm_bindgen::Clamped>, + vk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, + srs: wasm_bindgen::Clamped>, +) -> Result { + super::universal::verify(proof_js.0, vk.0, settings.0, srs.0).map_err(JsError::from) +} + +/// Verify aggregate proof in browser using wasm +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn verifyAggr( + proof_js: wasm_bindgen::Clamped>, + vk: wasm_bindgen::Clamped>, + logrows: u64, + srs: wasm_bindgen::Clamped>, + commitment: &str, +) -> Result { + verify_aggr(proof_js.0, vk.0, logrows, srs.0, commitment).map_err(JsError::from) +} + +/// Prove in browser using wasm +#[wasm_bindgen] +pub fn prove( + witness: wasm_bindgen::Clamped>, + pk: wasm_bindgen::Clamped>, + compiled_circuit: wasm_bindgen::Clamped>, + srs: wasm_bindgen::Clamped>, +) -> Result, JsError> { + super::universal::prove(witness.0, pk.0, compiled_circuit.0, srs.0).map_err(JsError::from) +} + +// VALIDATION FUNCTIONS + +/// Witness file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn witnessValidation(witness: wasm_bindgen::Clamped>) -> Result { + witness_validation(witness.0).map_err(JsError::from) +} +/// Compiled circuit validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn compiledCircuitValidation( + compiled_circuit: wasm_bindgen::Clamped>, +) -> Result { + compiled_circuit_validation(compiled_circuit.0).map_err(JsError::from) +} +/// Input file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn inputValidation(input: wasm_bindgen::Clamped>) -> Result { + input_validation(input.0).map_err(JsError::from) +} +/// Proof file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn proofValidation(proof: wasm_bindgen::Clamped>) -> Result { + proof_validation(proof.0).map_err(JsError::from) +} +/// Vk file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn vkValidation( + vk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, +) -> Result { + vk_validation(vk.0, settings.0).map_err(JsError::from) +} +/// Pk file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn pkValidation( + pk: wasm_bindgen::Clamped>, + settings: wasm_bindgen::Clamped>, +) -> Result { + pk_validation(pk.0, settings.0).map_err(JsError::from) +} +/// Settings file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn settingsValidation(settings: wasm_bindgen::Clamped>) -> Result { + settings_validation(settings.0).map_err(JsError::from) +} +/// Srs file validation +#[wasm_bindgen] +#[allow(non_snake_case)] +pub fn srsValidation(srs: wasm_bindgen::Clamped>) -> Result { + srs_validation(srs.0).map_err(JsError::from) +} + +/// HELPER FUNCTIONS +pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { + let mut n: u128 = 0; + for &b in arr.iter().rev() { + n <<= 8; + n |= b as u128; + } + n +} diff --git a/src/circuit/modules/polycommit.rs b/src/circuit/modules/polycommit.rs index 0a2de89d6..c9a4cfa54 100644 --- a/src/circuit/modules/polycommit.rs +++ b/src/circuit/modules/polycommit.rs @@ -219,7 +219,7 @@ mod tests { fn polycommit_chip_for_a_range_of_input_sizes() { let rng = rand::rngs::OsRng; - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] env_logger::init(); { @@ -247,7 +247,7 @@ mod tests { #[test] #[ignore] fn polycommit_chip_much_longer_input() { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] env_logger::init(); let rng = rand::rngs::OsRng; diff --git a/src/circuit/modules/poseidon.rs b/src/circuit/modules/poseidon.rs index 5a05fe81a..f2a295a1c 100644 --- a/src/circuit/modules/poseidon.rs +++ b/src/circuit/modules/poseidon.rs @@ -560,7 +560,7 @@ mod tests { fn hash_for_a_range_of_input_sizes() { let rng = rand::rngs::OsRng; - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] env_logger::init(); { diff --git a/src/circuit/ops/chip.rs b/src/circuit/ops/chip.rs index 1316d5f98..8deb27686 100644 --- a/src/circuit/ops/chip.rs +++ b/src/circuit/ops/chip.rs @@ -14,6 +14,7 @@ use pyo3::{ types::PyString, }; use serde::{Deserialize, Serialize}; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tosubcommand::ToFlags; use crate::{ @@ -49,6 +50,7 @@ impl std::fmt::Display for CheckMode { } } +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for CheckMode { /// Convert the struct to a subcommand string fn to_flags(&self) -> Vec { @@ -88,6 +90,7 @@ impl std::fmt::Display for Tolerance { } } +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for Tolerance { /// Convert the struct to a subcommand string fn to_flags(&self) -> Vec { diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index c03bb638a..90f0ec134 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -3,7 +3,7 @@ use crate::{ fieldutils::IntegerRep, tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor}, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use colored::Colorize; use halo2_proofs::{ circuit::Region, @@ -193,7 +193,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a self.settings.legs } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// pub fn debug_report(&self) { log::debug!( diff --git a/src/circuit/table.rs b/src/circuit/table.rs index 95600cc57..4be1fa6df 100644 --- a/src/circuit/table.rs +++ b/src/circuit/table.rs @@ -15,7 +15,7 @@ use crate::{ tensor::{Tensor, TensorType}, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::execute::EZKL_REPO_PATH; use crate::circuit::lookup::LookupOp; @@ -28,14 +28,14 @@ pub const RANGE_MULTIPLIER: IntegerRep = 2; /// The safety factor offset for the number of rows in the lookup table. pub const RESERVED_BLINDING_ROWS_PAD: usize = 3; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] lazy_static::lazy_static! { /// an optional directory to read and write the lookup table cache pub static ref LOOKUP_CACHE: String = format!("{}/cache", *EZKL_REPO_PATH); } /// The lookup table cache is disabled on wasm32 target. -#[cfg(target_arch = "wasm32")] +#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] pub const LOOKUP_CACHE: &str = ""; #[derive(Debug, Clone)] diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index cb1bdf8cc..f4e4c4386 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -8,7 +8,10 @@ use halo2_proofs::{ }; use halo2curves::bn256::Fr as F; use halo2curves::ff::{Field, PrimeField}; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(not(any( + all(target_arch = "wasm32", target_os = "unknown"), + not(feature = "ezkl") +)))] use ops::lookup::LookupOp; use ops::region::RegionCtx; use rand::rngs::OsRng; @@ -244,7 +247,10 @@ mod matmul_col_overflow { } #[cfg(test)] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] mod matmul_col_ultra_overflow_double_col { use halo2_proofs::poly::kzg::{ @@ -362,7 +368,10 @@ mod matmul_col_ultra_overflow_double_col { } #[cfg(test)] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] mod matmul_col_ultra_overflow { use halo2_proofs::poly::kzg::{ @@ -1145,7 +1154,10 @@ mod conv { } #[cfg(test)] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] mod conv_col_ultra_overflow { use halo2_proofs::poly::{ @@ -1286,7 +1298,10 @@ mod conv_col_ultra_overflow { #[cfg(test)] // not wasm 32 unknown -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] mod conv_relu_col_ultra_overflow { use halo2_proofs::poly::kzg::{ @@ -2449,7 +2464,10 @@ mod relu { } #[cfg(test)] -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] mod lookup_ultra_overflow { use super::*; use halo2_proofs::{ diff --git a/src/commands.rs b/src/commands.rs index baba9b874..841529bde 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -1,4 +1,3 @@ -#[cfg(not(target_arch = "wasm32"))] use alloy::primitives::Address as H160; use clap::{Command, Parser, Subcommand}; use clap_complete::{generate, Generator, Shell}; @@ -17,7 +16,6 @@ use tosubcommand::{ToFlags, ToSubcommand}; use crate::{pfsys::ProofType, Commitments, RunArgs}; use crate::circuit::CheckMode; -#[cfg(not(target_arch = "wasm32"))] use crate::graph::TestDataSource; use crate::pfsys::TranscriptType; @@ -189,7 +187,7 @@ pub enum ContractType { /// Deploys a verifier contrat tailored to the circuit and not reusable Verifier { /// Whether to deploy a reusable verifier. This can reduce state bloat on-chain since you need only deploy a verifying key artifact (vka) for a given circuit which is significantly smaller than the verifier contract (up to 4 times smaller for large circuits) - /// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain. + /// Can also be used as an alternative to aggregation for verifiers that are otherwise too large to fit on-chain. reusable: bool, }, /// Deploys a verifying key artifact that the reusable verifier loads into memory during runtime. Encodes the circuit specific data that was otherwise hardcoded onto the stack. @@ -244,28 +242,24 @@ impl From<&str> for ContractType { } -#[cfg(not(target_arch = "wasm32"))] #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] /// wrapper for H160 to make it easy to parse into flag vals pub struct H160Flag { inner: H160, } -#[cfg(not(target_arch = "wasm32"))] impl From for H160 { fn from(val: H160Flag) -> H160 { val.inner } } -#[cfg(not(target_arch = "wasm32"))] impl ToFlags for H160Flag { fn to_flags(&self) -> Vec { vec![format!("{:#x}", self.inner)] } } -#[cfg(not(target_arch = "wasm32"))] impl From<&str> for H160Flag { fn from(s: &str) -> Self { Self { @@ -461,8 +455,7 @@ pub enum Commands { }, /// Calibrates the proving scale, lookup bits and logrows from a circuit settings file. - #[cfg(not(target_arch = "wasm32"))] - CalibrateSettings { + CalibrateSettings { /// The path to the .json calibration data file. #[arg(short = 'D', long, default_value = DEFAULT_CALIBRATION_FILE, value_hint = clap::ValueHint::FilePath)] data: Option, @@ -512,8 +505,7 @@ pub enum Commands { commitment: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Gets an SRS from a circuit settings file. + /// Gets an SRS from a circuit settings file. #[command(name = "get-srs")] GetSrs { /// The path to output the desired srs file, if set to None will save to $EZKL_REPO_PATH/srs @@ -648,8 +640,7 @@ pub enum Commands { #[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION, action = clap::ArgAction::SetTrue)] disable_selector_compression: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information + /// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information #[command(arg_required_else_help = true)] SetupTestEvmData { /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) @@ -673,8 +664,7 @@ pub enum Commands { #[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)] output_source: TestDataSource, }, - #[cfg(not(target_arch = "wasm32"))] - /// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data. + /// The Data Attestation Verifier contract stores the account calls to fetch data to feed into ezkl. This call data can be updated by an admin account. This tests that admin account is able to update this call data. #[command(arg_required_else_help = true)] TestUpdateAccountCalls { /// The path to the verifier contract's address @@ -687,8 +677,7 @@ pub enum Commands { #[arg(short = 'U', long, value_hint = clap::ValueHint::Url)] rpc_url: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Swaps the positions in the transcript that correspond to commitments + /// Swaps the positions in the transcript that correspond to commitments SwapProofCommitments { /// The path to the proof file #[arg(short = 'P', long, default_value = DEFAULT_PROOF, value_hint = clap::ValueHint::FilePath)] @@ -698,8 +687,7 @@ pub enum Commands { witness_path: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Loads model, data, and creates proof + /// Loads model, data, and creates proof Prove { /// The path to the .json witness file (generated using the gen-witness command) #[arg(short = 'W', long, default_value = DEFAULT_WITNESS, value_hint = clap::ValueHint::FilePath)] @@ -729,8 +717,7 @@ pub enum Commands { #[arg(long, default_value = DEFAULT_CHECKMODE, value_hint = clap::ValueHint::Other)] check_mode: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Encodes a proof into evm calldata + /// Encodes a proof into evm calldata #[command(name = "encode-evm-calldata")] EncodeEvmCalldata { /// The path to the proof file (generated using the prove command) @@ -743,8 +730,7 @@ pub enum Commands { #[arg(long, value_hint = clap::ValueHint::Other)] addr_vk: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Creates an Evm verifier for a single proof + /// Creates an Evm verifier for a single proof #[command(name = "create-evm-verifier")] CreateEvmVerifier { /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs @@ -762,12 +748,11 @@ pub enum Commands { /// The path to output the Solidity verifier ABI #[arg(long, default_value = DEFAULT_VERIFIER_ABI, value_hint = clap::ValueHint::FilePath)] abi_path: Option, - /// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier. + /// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier. #[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)] reusable: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier + /// Creates an Evm verifier artifact for a single proof to be used by the reusable verifier #[command(name = "create-evm-vka")] CreateEvmVKArtifact { /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs @@ -786,8 +771,7 @@ pub enum Commands { #[arg(long, default_value = DEFAULT_VK_ABI, value_hint = clap::ValueHint::FilePath)] abi_path: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Creates an Evm verifier that attests to on-chain inputs for a single proof + /// Creates an Evm verifier that attests to on-chain inputs for a single proof #[command(name = "create-evm-da")] CreateEvmDataAttestation { /// The path to load circuit settings .json file from (generated using the gen-settings command) @@ -811,8 +795,7 @@ pub enum Commands { witness: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Creates an Evm verifier for an aggregate proof + /// Creates an Evm verifier for an aggregate proof #[command(name = "create-evm-verifier-aggr")] CreateEvmVerifierAggr { /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs @@ -833,7 +816,7 @@ pub enum Commands { // logrows used for aggregation circuit #[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS, value_hint = clap::ValueHint::Other)] logrows: Option, - /// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier. + /// Whether the to render the verifier as reusable or not. If true, you will need to deploy a VK artifact, passing it as part of the calldata to the verifier. #[arg(long, default_value = DEFAULT_RENDER_REUSABLE, action = clap::ArgAction::SetTrue)] reusable: Option, }, @@ -876,8 +859,7 @@ pub enum Commands { #[arg(long, default_value = DEFAULT_COMMITMENT, value_hint = clap::ValueHint::Other)] commitment: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl + /// Deploys an evm contract (verifier, reusable verifier, or vk artifact) that is generated by ezkl DeployEvm { /// The path to the Solidity code (generated using the create-evm-verifier command) #[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)] @@ -898,8 +880,7 @@ pub enum Commands { #[arg(long = "contract-type", short = 'C', default_value = DEFAULT_CONTRACT_DEPLOYMENT_TYPE, value_hint = clap::ValueHint::Other)] contract: ContractType, }, - #[cfg(not(target_arch = "wasm32"))] - /// Deploys an evm verifier that allows for data attestation + /// Deploys an evm verifier that allows for data attestation #[command(name = "deploy-evm-da")] DeployEvmDataAttestation { /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) @@ -924,8 +905,7 @@ pub enum Commands { #[arg(short = 'P', long, value_hint = clap::ValueHint::Other)] private_key: Option, }, - #[cfg(not(target_arch = "wasm32"))] - /// Verifies a proof using a local Evm executor, returning accept or reject + /// Verifies a proof using a local Evm executor, returning accept or reject #[command(name = "verify-evm")] VerifyEvm { /// The path to the proof file (generated using the prove command) diff --git a/src/eth.rs b/src/eth.rs index 0b01b4547..b63978fef 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -1,7 +1,6 @@ use crate::graph::input::{CallsToAccount, FileSourceInner, GraphData}; use crate::graph::modules::POSEIDON_INSTANCES; use crate::graph::DataSource; -#[cfg(not(target_arch = "wasm32"))] use crate::graph::GraphSettings; use crate::pfsys::evm::EvmVerificationError; use crate::pfsys::Snark; @@ -11,8 +10,6 @@ use alloy::core::primitives::Bytes; use alloy::core::primitives::U256; use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken}; use alloy::dyn_abi::abi::TokenSeq; -#[cfg(target_arch = "wasm32")] -use alloy::prelude::Wallet; // use alloy::providers::Middleware; use alloy::json_abi::JsonAbi; use alloy::node_bindings::Anvil; @@ -285,7 +282,6 @@ pub type EthersClient = Arc< pub type ContractFactory = CallBuilder, Arc, ()>; /// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used. -#[cfg(not(target_arch = "wasm32"))] pub async fn setup_eth_backend( rpc_url: Option<&str>, private_key: Option<&str>, @@ -614,7 +610,6 @@ pub async fn update_account_calls( } /// Verify a proof using a Solidity verifier contract -#[cfg(not(target_arch = "wasm32"))] pub async fn verify_proof_via_solidity( proof: Snark, addr: H160, @@ -716,7 +711,6 @@ pub async fn setup_test_contract, Ethereum>>( /// Verify a proof using a Solidity DataAttestation contract. /// Used for testing purposes. -#[cfg(not(target_arch = "wasm32"))] pub async fn verify_proof_with_data_attestation( proof: Snark, addr_verifier: H160, @@ -829,7 +823,6 @@ pub async fn test_on_chain_data, Ethereum>>( } /// Reads on-chain inputs, returning the raw encoded data returned from making all the calls in on_chain_input_data -#[cfg(not(target_arch = "wasm32"))] pub async fn read_on_chain_inputs, Ethereum>>( client: Arc, address: H160, @@ -863,7 +856,6 @@ pub async fn read_on_chain_inputs, Ethereum>> } /// -#[cfg(not(target_arch = "wasm32"))] pub async fn evm_quantize, Ethereum>>( client: Arc, scales: Vec, @@ -964,7 +956,6 @@ fn get_sol_contract_factory<'a, M: 'static + Provider, Ethereum>, T } /// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode -#[cfg(not(target_arch = "wasm32"))] pub async fn get_contract_artifacts( sol_code_path: PathBuf, contract_name: &str, diff --git a/src/execute.rs b/src/execute.rs index 97ccb4455..500258876 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -1,18 +1,13 @@ use crate::circuit::region::RegionSettings; use crate::circuit::CheckMode; -#[cfg(not(target_arch = "wasm32"))] use crate::commands::CalibrationTarget; -#[cfg(not(target_arch = "wasm32"))] use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity}; -#[cfg(not(target_arch = "wasm32"))] #[allow(unused_imports)] use crate::eth::{fix_da_sol, get_contract_artifacts, verify_proof_via_solidity}; use crate::graph::input::GraphData; use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model}; -#[cfg(not(target_arch = "wasm32"))] use crate::graph::{TestDataSource, TestSources}; use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript}; -#[cfg(not(target_arch = "wasm32"))] use crate::pfsys::{ create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType, }; @@ -21,11 +16,9 @@ use crate::pfsys::{ }; use crate::pfsys::{save_vk, srs::*}; use crate::tensor::TensorError; -#[cfg(not(target_arch = "wasm32"))] use crate::EZKL_BUF_CAPACITY; use crate::{commands::*, EZKLError}; use crate::{Commitments, RunArgs}; -#[cfg(not(target_arch = "wasm32"))] use colored::Colorize; #[cfg(unix)] use gag::Gag; @@ -45,17 +38,13 @@ use halo2_proofs::poly::kzg::{ }; use halo2_proofs::poly::VerificationStrategy; use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer}; -#[cfg(not(target_arch = "wasm32"))] use halo2_solidity_verifier; use halo2curves::bn256::{Bn256, Fr, G1Affine}; use halo2curves::ff::{FromUniformBytes, WithSmallOrderMulGroup}; use halo2curves::serde::SerdeObject; -#[cfg(not(target_arch = "wasm32"))] use indicatif::{ProgressBar, ProgressStyle}; use instant::Instant; -#[cfg(not(target_arch = "wasm32"))] use itertools::Itertools; -#[cfg(not(target_arch = "wasm32"))] use log::debug; use log::{info, trace, warn}; use serde::de::DeserializeOwned; @@ -65,9 +54,7 @@ use snark_verifier::system::halo2::compile; use snark_verifier::system::halo2::transcript::evm::EvmTranscript; use snark_verifier::system::halo2::Config; use std::fs::File; -#[cfg(not(target_arch = "wasm32"))] use std::io::BufWriter; -#[cfg(not(target_arch = "wasm32"))] use std::io::{Cursor, Write}; use std::path::Path; use std::path::PathBuf; @@ -128,7 +115,6 @@ pub async fn run(command: Commands) -> Result { logrows as u32, commitment.unwrap_or(Commitments::from_str(DEFAULT_COMMITMENT).unwrap()), ), - #[cfg(not(target_arch = "wasm32"))] Commands::GetSrs { srs_path, settings_path, @@ -145,7 +131,6 @@ pub async fn run(command: Commands) -> Result { settings_path.unwrap_or(DEFAULT_SETTINGS.into()), args, ), - #[cfg(not(target_arch = "wasm32"))] Commands::CalibrateSettings { model, settings_path, @@ -188,7 +173,6 @@ pub async fn run(command: Commands) -> Result { model.unwrap_or(DEFAULT_MODEL.into()), witness.unwrap_or(DEFAULT_WITNESS.into()), ), - #[cfg(not(target_arch = "wasm32"))] Commands::CreateEvmVerifier { vk_path, srs_path, @@ -207,7 +191,6 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::EncodeEvmCalldata { proof_path, calldata_path, @@ -235,7 +218,6 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::CreateEvmDataAttestation { settings_path, sol_code_path, @@ -252,7 +234,6 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::CreateEvmVerifierAggr { vk_path, srs_path, @@ -298,7 +279,6 @@ pub async fn run(command: Commands) -> Result { disable_selector_compression .unwrap_or(DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap()), ), - #[cfg(not(target_arch = "wasm32"))] Commands::SetupTestEvmData { data, compiled_circuit, @@ -317,13 +297,11 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::TestUpdateAccountCalls { addr, data, rpc_url, } => test_update_account_calls(addr, data.unwrap_or(DEFAULT_DATA.into()), rpc_url).await, - #[cfg(not(target_arch = "wasm32"))] Commands::SwapProofCommitments { proof_path, witness_path, @@ -333,7 +311,6 @@ pub async fn run(command: Commands) -> Result { ) .map(|e| serde_json::to_string(&e).unwrap()), - #[cfg(not(target_arch = "wasm32"))] Commands::Prove { witness, compiled_circuit, @@ -433,7 +410,6 @@ pub async fn run(command: Commands) -> Result { commitment.into(), ) .map(|e| serde_json::to_string(&e).unwrap()), - #[cfg(not(target_arch = "wasm32"))] Commands::DeployEvm { sol_code_path, rpc_url, @@ -452,7 +428,6 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::DeployEvmDataAttestation { data, settings_path, @@ -473,7 +448,6 @@ pub async fn run(command: Commands) -> Result { ) .await } - #[cfg(not(target_arch = "wasm32"))] Commands::VerifyEvm { proof_path, addr_verifier, @@ -593,7 +567,6 @@ pub(crate) fn gen_srs_cmd( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] async fn fetch_srs(uri: &str) -> Result, EZKLError> { let pb = { let pb = init_spinner(); @@ -613,7 +586,6 @@ async fn fetch_srs(uri: &str) -> Result, EZKLError> { Ok(std::mem::take(&mut buf)) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) fn get_file_hash(path: &PathBuf) -> Result { use std::io::Read; let file = std::fs::File::open(path)?; @@ -632,7 +604,6 @@ pub(crate) fn get_file_hash(path: &PathBuf) -> Result { Ok(hash) } -#[cfg(not(target_arch = "wasm32"))] fn check_srs_hash( logrows: u32, srs_path: Option, @@ -658,7 +629,6 @@ fn check_srs_hash( Ok(hash) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn get_srs_cmd( srs_path: Option, settings_path: Option, @@ -701,13 +671,10 @@ pub(crate) async fn get_srs_cmd( let srs_uri = format!("{}{}", PUBLIC_SRS_URL, k); let mut reader = Cursor::new(fetch_srs(&srs_uri).await?); // check the SRS - #[cfg(not(target_arch = "wasm32"))] - let pb = init_spinner(); - #[cfg(not(target_arch = "wasm32"))] - pb.set_message("Validating SRS (this may take a while) ..."); + let pb = init_spinner(); + pb.set_message("Validating SRS (this may take a while) ..."); let params = ParamsKZG::::read(&mut reader)?; - #[cfg(not(target_arch = "wasm32"))] - pb.finish_with_message("SRS validated."); + pb.finish_with_message("SRS validated."); info!("Saving SRS to disk..."); let mut file = std::fs::File::create(get_srs_path(k, srs_path.clone(), commitment))?; @@ -760,9 +727,8 @@ pub(crate) async fn gen_witness( None }; - #[cfg(not(target_arch = "wasm32"))] - let mut input = circuit.load_graph_input(&data).await?; - #[cfg(target_arch = "wasm32")] + let mut input = circuit.load_graph_input(&data).await?; + #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] let mut input = circuit.load_graph_input(&data)?; // if any of the settings have kzg visibility then we need to load the srs @@ -858,7 +824,6 @@ pub(crate) fn gen_circuit_settings( } // not for wasm targets -#[cfg(not(target_arch = "wasm32"))] pub(crate) fn init_spinner() -> ProgressBar { let pb = indicatif::ProgressBar::new_spinner(); pb.set_draw_target(indicatif::ProgressDrawTarget::stdout()); @@ -880,7 +845,6 @@ pub(crate) fn init_spinner() -> ProgressBar { } // not for wasm targets -#[cfg(not(target_arch = "wasm32"))] pub(crate) fn init_bar(len: u64) -> ProgressBar { let pb = ProgressBar::new(len); pb.set_draw_target(indicatif::ProgressDrawTarget::stdout()); @@ -894,7 +858,6 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar { pb } -#[cfg(not(target_arch = "wasm32"))] use colored_json::ToColoredJson; #[derive(Debug, Clone, Tabled)] @@ -994,7 +957,6 @@ impl AccuracyResults { } /// Calibrate the circuit parameters to a given a dataset -#[cfg(not(target_arch = "wasm32"))] #[allow(trivial_casts)] #[allow(clippy::too_many_arguments)] pub(crate) async fn calibrate( @@ -1127,12 +1089,12 @@ pub(crate) async fn calibrate( }; // if unix get a gag - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] let _r = match Gag::stdout() { Ok(g) => Some(g), _ => None, }; - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] let _g = match Gag::stderr() { Ok(g) => Some(g), _ => None, @@ -1191,9 +1153,9 @@ pub(crate) async fn calibrate( } // drop the gag - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] drop(_r); - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] drop(_g); let result = forward_pass_res.get(&key).ok_or("key not found")?; @@ -1408,7 +1370,6 @@ pub(crate) fn mock( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn create_evm_verifier( vk_path: PathBuf, srs_path: Option, @@ -1453,7 +1414,6 @@ pub(crate) async fn create_evm_verifier( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn create_evm_vka( vk_path: PathBuf, srs_path: Option, @@ -1494,7 +1454,6 @@ pub(crate) async fn create_evm_vka( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn create_evm_data_attestation( settings_path: PathBuf, sol_code_path: PathBuf, @@ -1571,7 +1530,6 @@ pub(crate) async fn create_evm_data_attestation( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn deploy_da_evm( data: PathBuf, settings_path: PathBuf, @@ -1598,7 +1556,6 @@ pub(crate) async fn deploy_da_evm( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn deploy_evm( sol_code_path: PathBuf, rpc_url: Option, @@ -1654,7 +1611,6 @@ pub(crate) fn encode_evm_calldata( Ok(encoded) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn verify_evm( proof_path: PathBuf, addr_verifier: H160Flag, @@ -1694,7 +1650,6 @@ pub(crate) async fn verify_evm( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn create_evm_aggregate_verifier( vk_path: PathBuf, srs_path: Option, @@ -1818,7 +1773,6 @@ pub(crate) fn setup( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn setup_test_evm_witness( data_path: PathBuf, compiled_circuit_path: PathBuf, @@ -1854,9 +1808,7 @@ pub(crate) async fn setup_test_evm_witness( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] use crate::pfsys::ProofType; -#[cfg(not(target_arch = "wasm32"))] pub(crate) async fn test_update_account_calls( addr: H160Flag, data: PathBuf, @@ -1869,7 +1821,6 @@ pub(crate) async fn test_update_account_calls( Ok(String::new()) } -#[cfg(not(target_arch = "wasm32"))] #[allow(clippy::too_many_arguments)] pub(crate) fn prove( data_path: PathBuf, @@ -2067,8 +2018,7 @@ pub(crate) fn mock_aggregate( } } // proof aggregation - #[cfg(not(target_arch = "wasm32"))] - let pb = { + let pb = { let pb = init_spinner(); pb.set_message("Aggregating (may take a while)..."); pb @@ -2079,8 +2029,7 @@ pub(crate) fn mock_aggregate( let prover = halo2_proofs::dev::MockProver::run(logrows, &circuit, vec![circuit.instances()]) .map_err(|e| ExecutionError::MockProverError(e.to_string()))?; prover.verify().map_err(ExecutionError::VerifyError)?; - #[cfg(not(target_arch = "wasm32"))] - pb.finish_with_message("Done."); + pb.finish_with_message("Done."); Ok(String::new()) } @@ -2174,8 +2123,7 @@ pub(crate) fn aggregate( } // proof aggregation - #[cfg(not(target_arch = "wasm32"))] - let pb = { + let pb = { let pb = init_spinner(); pb.set_message("Aggregating (may take a while)..."); pb @@ -2324,8 +2272,7 @@ pub(crate) fn aggregate( ); snark.save(&proof_path)?; - #[cfg(not(target_arch = "wasm32"))] - pb.finish_with_message("Done."); + pb.finish_with_message("Done."); Ok(snark) } diff --git a/src/graph/errors.rs b/src/graph/errors.rs index a7652fc0b..2b7e1efd1 100644 --- a/src/graph/errors.rs +++ b/src/graph/errors.rs @@ -48,7 +48,10 @@ pub enum GraphError { #[error("failed to ser/deser model: {0}")] ModelSerialize(#[from] bincode::Error), /// Tract error - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[tract] {0}")] TractError(#[from] tract_onnx::prelude::TractError), /// Packing exponent is too large @@ -85,11 +88,17 @@ pub enum GraphError { #[error("unknown dimension batch_size in model inputs, set batch_size in variables")] MissingBatchSize, /// Tokio postgres error - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[tokio postgres] {0}")] TokioPostgresError(#[from] tokio_postgres::Error), /// Eth error - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[eth] {0}")] EthError(#[from] crate::eth::EthError), /// Json error diff --git a/src/graph/input.rs b/src/graph/input.rs index e3f324c64..dbef47ecc 100644 --- a/src/graph/input.rs +++ b/src/graph/input.rs @@ -2,9 +2,9 @@ use super::errors::GraphError; use super::quantize_float; use crate::circuit::InputType; use crate::fieldutils::integer_rep_to_felt; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::graph::postgres::Client; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::tensor::Tensor; use crate::EZKL_BUF_CAPACITY; use halo2curves::bn256::Fr as Fp; @@ -20,12 +20,12 @@ use std::io::BufReader; use std::io::BufWriter; use std::io::Read; use std::panic::UnwindSafe; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_core::{ tract_data::{prelude::Tensor as TractTensor, TVec}, value::TValue, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_hir::tract_num_traits::ToPrimitive; type Decimals = u8; @@ -171,7 +171,7 @@ impl OnChainSource { } } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Inner elements of inputs/outputs coming from postgres DB #[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)] pub struct PostgresSource { @@ -189,7 +189,7 @@ pub struct PostgresSource { pub port: String, } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl PostgresSource { /// Create a new PostgresSource pub fn new( @@ -268,7 +268,7 @@ impl PostgresSource { } impl OnChainSource { - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Create dummy local on-chain data to test the OnChain data source pub async fn test_from_file_data( data: &FileSource, @@ -359,7 +359,7 @@ pub enum DataSource { /// On-chain data source. The first element is the calls to the account, and the second is the RPC url. OnChain(OnChainSource), /// Postgres DB - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] DB(PostgresSource), } @@ -419,7 +419,7 @@ impl<'de> Deserialize<'de> for DataSource { if let Ok(t) = second_try { return Ok(DataSource::OnChain(t)); } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] { let third_try: Result = serde_json::from_str(this_json.get()); if let Ok(t) = third_try { @@ -445,7 +445,7 @@ impl UnwindSafe for GraphData {} impl GraphData { // not wasm - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Convert the input data to tract data pub fn to_tract_data( &self, @@ -530,7 +530,7 @@ impl GraphData { "on-chain data cannot be split into batches".to_string(), )) } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] GraphData { input_data: DataSource::DB(data), output_data: _, diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 667744d65..acf9af412 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -7,7 +7,7 @@ pub mod modules; /// Inner elements of a computational graph that represent a single operation / constraints. pub mod node; /// postgres helper functions -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub mod postgres; /// Helper functions pub mod utilities; @@ -17,18 +17,19 @@ pub mod vars; /// errors for the graph pub mod errors; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use colored_json::ToColoredJson; -#[cfg(unix)] +#[cfg(all(not(not(feature = "ezkl")), unix))] use gag::Gag; use halo2_proofs::plonk::VerifyingKey; use halo2_proofs::poly::commitment::CommitmentScheme; pub use input::DataSource; use itertools::Itertools; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tosubcommand::ToFlags; use self::errors::GraphError; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use self::input::OnChainSource; use self::input::{FileSource, GraphData}; use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes}; @@ -48,7 +49,7 @@ use halo2_proofs::{ }; use halo2curves::bn256::{self, Fr as Fp, G1Affine}; use halo2curves::ff::{Field, PrimeField}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use lazy_static::lazy_static; use log::{debug, error, trace, warn}; use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -78,7 +79,7 @@ pub const MAX_NUM_LOOKUP_COLS: usize = 12; pub const MAX_LOOKUP_ABS: IntegerRep = (MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS); -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] lazy_static! { /// Max circuit area pub static ref EZKL_MAX_CIRCUIT_AREA: Option = @@ -89,7 +90,7 @@ lazy_static! { }; } -#[cfg(target_arch = "wasm32")] +#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] const EZKL_MAX_CIRCUIT_AREA: Option = None; /// @@ -384,7 +385,7 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec) -> Resu #[cfg(feature = "python-bindings")] fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec>) -> Result<(), PyErr> { - use crate::python::PyG1Affine; + use crate::bindings::python::PyG1Affine; let poseidon_hash: Vec> = commits .iter() .map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect()) @@ -697,6 +698,7 @@ impl std::fmt::Display for TestDataSource { } } +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for TestDataSource {} impl From for TestDataSource { @@ -885,7 +887,7 @@ impl GraphCircuit { public_inputs.processed_outputs = elements.processed_outputs.clone(); } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] debug!( "rescaled and processed public inputs: {}", serde_json::to_string(&public_inputs)?.to_colored_json_auto()? @@ -895,7 +897,7 @@ impl GraphCircuit { } /// - #[cfg(target_arch = "wasm32")] + #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] pub fn load_graph_input(&mut self, data: &GraphData) -> Result>, GraphError> { let shapes = self.model().graph.input_shapes()?; let scales = self.model().graph.get_input_scales(); @@ -922,7 +924,7 @@ impl GraphCircuit { } /// - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub async fn load_graph_input( &mut self, data: &GraphData, @@ -936,7 +938,7 @@ impl GraphCircuit { .await } - #[cfg(target_arch = "wasm32")] + #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] /// Process the data source for the model fn process_data_source( &mut self, @@ -953,7 +955,7 @@ impl GraphCircuit { } } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Process the data source for the model async fn process_data_source( &mut self, @@ -983,7 +985,7 @@ impl GraphCircuit { } /// Prepare on chain test data - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub async fn load_on_chain_data( &mut self, source: OnChainSource, @@ -1202,12 +1204,12 @@ impl GraphCircuit { settings.required_range_checks = vec![(0, max_range_size)]; let mut cs = ConstraintSystem::default(); // if unix get a gag - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] let _r = match Gag::stdout() { Ok(g) => Some(g), _ => None, }; - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] let _g = match Gag::stderr() { Ok(g) => Some(g), _ => None, @@ -1216,9 +1218,9 @@ impl GraphCircuit { Self::configure_with_params(&mut cs, settings); // drop the gag - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] drop(_r); - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] drop(_g); #[cfg(feature = "mv-lookup")] @@ -1347,7 +1349,7 @@ impl GraphCircuit { visibility, ); - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] log::trace!( "witness: \n {}", &witness.as_json()?.to_colored_json_auto()? @@ -1357,7 +1359,7 @@ impl GraphCircuit { } /// Create a new circuit from a set of input data and [RunArgs]. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn from_run_args( run_args: &RunArgs, model_path: &std::path::Path, @@ -1367,7 +1369,7 @@ impl GraphCircuit { } /// Create a new circuit from a set of input data and [GraphSettings]. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn from_settings( params: &GraphSettings, model_path: &std::path::Path, @@ -1382,7 +1384,7 @@ impl GraphCircuit { } /// - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub async fn populate_on_chain_test_data( &mut self, data: &mut GraphData, @@ -1475,7 +1477,7 @@ impl CircuitSize { } } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Export the ezkl configuration as json pub fn as_json(&self) -> Result { let serialized = match serde_json::to_string(&self) { @@ -1563,7 +1565,7 @@ impl Circuit for GraphCircuit { let circuit_size = CircuitSize::from_cs(cs, params.run_args.logrows); - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] debug!( "circuit size: \n {}", circuit_size diff --git a/src/graph/model.rs b/src/graph/model.rs index e53c4c098..809fefcfd 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -21,9 +21,9 @@ use crate::{ }; use halo2curves::bn256::Fr as Fp; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use super::input::GraphData; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use colored::Colorize; use halo2_proofs::{ circuit::{Layouter, Value}, @@ -36,29 +36,29 @@ use log::{debug, info, trace}; use serde::Deserialize; use serde::Serialize; use std::collections::BTreeMap; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use std::collections::HashMap; use std::collections::HashSet; use std::fs; use std::io::Read; use std::path::PathBuf; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tabled::Table; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::prelude::{ Framework, Graph, InferenceFact, InferenceModelExt, SymbolValues, TypedFact, TypedOp, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_core::internal::DatumType; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_hir::ops::scan::Scan; use unzip_n::unzip_n; unzip_n!(pub 3); -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] type TractResult = (Graph>, SymbolValues); /// The result of a forward pass. #[derive(Clone, Debug)] @@ -470,7 +470,7 @@ impl Model { /// # Arguments /// * `reader` - A reader for an Onnx file. /// * `run_args` - [RunArgs] - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn new(reader: &mut dyn std::io::Read, run_args: &RunArgs) -> Result { let visibility = VarVisibility::from_args(run_args)?; @@ -517,7 +517,7 @@ impl Model { check_mode: CheckMode, ) -> Result { let instance_shapes = self.instance_shapes()?; - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] debug!( "{} {} {}", "model has".blue(), @@ -574,13 +574,13 @@ impl Model { version: env!("CARGO_PKG_VERSION").to_string(), num_blinding_factors: None, // unix time timestamp - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] timestamp: Some( instant::SystemTime::now() .duration_since(instant::SystemTime::UNIX_EPOCH)? .as_millis(), ), - #[cfg(target_arch = "wasm32")] + #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] timestamp: None, }) } @@ -609,7 +609,7 @@ impl Model { /// * `reader` - A reader for an Onnx file. /// * `scale` - The scale to use for quantization. /// * `public_params` - Whether to make the params public. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn load_onnx_using_tract( reader: &mut dyn std::io::Read, run_args: &RunArgs, @@ -664,7 +664,7 @@ impl Model { /// * `reader` - A reader for an Onnx file. /// * `scale` - The scale to use for quantization. /// * `public_params` - Whether to make the params public. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn load_onnx_model( reader: &mut dyn std::io::Read, run_args: &RunArgs, @@ -700,7 +700,7 @@ impl Model { } /// Formats nodes (including subgraphs) into tables ! - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn table_nodes(&self) -> String { let mut node_accumulator = vec![]; let mut string = String::new(); @@ -742,7 +742,7 @@ impl Model { /// * `visibility` - Which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility]. /// * `input_scales` - The scales of the model's inputs. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn nodes_from_graph( graph: &Graph>, run_args: &RunArgs, @@ -931,7 +931,7 @@ impl Model { Ok(nodes) } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Removes all nodes that are consts with 0 uses fn remove_unused_nodes(nodes: &mut BTreeMap) { // remove all nodes that are consts with 0 uses now @@ -950,7 +950,7 @@ impl Model { }); } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Run tract onnx model on sample data ! pub fn run_onnx_predictions( run_args: &RunArgs, @@ -991,7 +991,7 @@ impl Model { /// Creates a `Model` from parsed run_args /// # Arguments /// * `params` - A [GraphSettings] struct holding parsed CLI arguments. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result { let mut file = std::fs::File::open(model).map_err(|e| { GraphError::ReadWriteFileError(model.display().to_string(), e.to_string()) @@ -1166,7 +1166,7 @@ impl Model { })?; } // Then number of columns in the circuits - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] thread_safe_region.debug_report(); *constants = thread_safe_region.assigned_constants().clone(); @@ -1197,7 +1197,7 @@ impl Model { for (idx, node) in self.graph.nodes.iter() { debug!("laying out {}: {}", idx, node.as_str(),); // Then number of columns in the circuits - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] region.debug_report(); debug!("input indices: {:?}", node.inputs()); debug!("output scales: {:?}", node.out_scales()); @@ -1451,7 +1451,7 @@ impl Model { trace!("dummy model layout took: {:?}", duration); // Then number of columns in the circuits - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] region.debug_report(); let outputs = outputs diff --git a/src/graph/node.rs b/src/graph/node.rs index a46654752..2290c7588 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -1,9 +1,9 @@ use super::scale_to_multiplier; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use super::utilities::node_output_shapes; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use super::VarScales; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use super::Visibility; use crate::circuit::hybrid::HybridOp; use crate::circuit::lookup::LookupOp; @@ -13,29 +13,29 @@ use crate::circuit::Constant; use crate::circuit::Input; use crate::circuit::Op; use crate::circuit::Unknown; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::graph::errors::GraphError; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::graph::new_op_from_onnx; use crate::tensor::TensorError; use halo2curves::bn256::Fr as Fp; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use log::trace; use serde::Deserialize; use serde::Serialize; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use std::collections::BTreeMap; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use std::fmt; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tabled::Tabled; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::{ self, prelude::{Node as OnnxNode, SymbolValues, TypedFact, TypedOp}, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn display_vector(v: &Vec) -> String { if !v.is_empty() { format!("{:?}", v) @@ -44,7 +44,7 @@ fn display_vector(v: &Vec) -> String { } } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn display_opkind(v: &SupportedOp) -> String { v.as_string() } @@ -303,7 +303,7 @@ impl SupportedOp { } } - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn homogenous_rescale( &self, in_scales: Vec, @@ -441,7 +441,7 @@ pub struct Node { pub num_uses: usize, } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl Tabled for Node { const LENGTH: usize = 6; @@ -481,7 +481,7 @@ impl Node { /// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph. /// * `public_params` - flag if parameters of model are public /// * `idx` - The node's unique identifier. - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] #[allow(clippy::too_many_arguments)] pub fn new( node: OnnxNode>, @@ -625,7 +625,7 @@ impl Node { } } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn rescale_const_with_single_use( constant: &mut Constant, in_scales: Vec, diff --git a/src/graph/postgres.rs b/src/graph/postgres.rs index e5d59e65a..aef4d1f53 100644 --- a/src/graph/postgres.rs +++ b/src/graph/postgres.rs @@ -1,7 +1,7 @@ use log::{debug, error, info}; use std::fmt::Debug; use std::net::IpAddr; -#[cfg(unix)] +#[cfg(all(not(not(feature = "ezkl")), unix))] use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -150,7 +150,7 @@ impl Config { /// Adds a Unix socket host to the configuration. /// /// Unlike `host`, this method allows non-UTF8 paths. - #[cfg(unix)] + #[cfg(all(not(not(feature = "ezkl")), unix))] pub fn host_path(&mut self, host: T) -> &mut Config where T: AsRef, diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 8bb2eb193..8c77269f7 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -1,12 +1,12 @@ use super::errors::GraphError; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use super::VarScales; use super::{Rescaled, SupportedOp, Visibility}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::circuit::hybrid::HybridOp; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::circuit::lookup::LookupOp; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::circuit::poly::PolyOp; use crate::circuit::Op; use crate::fieldutils::IntegerRep; @@ -14,13 +14,13 @@ use crate::tensor::{Tensor, TensorError, TensorType}; use halo2curves::bn256::Fr as Fp; use halo2curves::ff::PrimeField; use itertools::Itertools; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use log::{debug, warn}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use std::sync::Arc; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_core::ops::{ array::{ Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd, @@ -33,7 +33,7 @@ use tract_onnx::tract_core::ops::{ nn::{LeakyRelu, Reduce, Softmax}, Downsample, }; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::tract_hir::{ internal::DimLike, ops::array::{Pad, PadMode, TypedConcat}, @@ -90,7 +90,7 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale { mult.log2().round() as crate::Scale } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// extract padding from a onnx node. pub fn extract_padding( pool_spec: &PoolSpec, @@ -109,7 +109,7 @@ pub fn extract_padding( Ok(padding) } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Extracts the strides from a onnx node. pub fn extract_strides(pool_spec: &PoolSpec) -> Result, GraphError> { Ok(pool_spec @@ -120,7 +120,7 @@ pub fn extract_strides(pool_spec: &PoolSpec) -> Result, GraphError> { } /// Gets the shape of a onnx node's outlets. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn node_output_shapes( node: &OnnxNode>, symbol_values: &SymbolValues, @@ -135,9 +135,9 @@ pub fn node_output_shapes( } Ok(shapes) } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tract_onnx::prelude::SymbolValues; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] /// Extracts the raw values from a tensor. pub fn extract_tensor_value( input: Arc, @@ -246,7 +246,7 @@ pub fn extract_tensor_value( Ok(const_value) } -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn load_op( op: &dyn tract_onnx::prelude::Op, idx: usize, @@ -270,7 +270,7 @@ fn load_op( /// * `param_visibility` - [Visibility] of the node. /// * `node` - the [OnnxNode] to be matched. /// * `inputs` - the node's inputs. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub fn new_op_from_onnx( idx: usize, scales: &VarScales, diff --git a/src/graph/vars.rs b/src/graph/vars.rs index 14d3bf476..292593a81 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -14,6 +14,7 @@ use pyo3::{ }; use serde::{Deserialize, Serialize}; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tosubcommand::ToFlags; use self::errors::GraphError; @@ -64,6 +65,7 @@ impl Display for Visibility { } } +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for Visibility { fn to_flags(&self) -> Vec { vec![format!("{}", self)] diff --git a/src/lib.rs b/src/lib.rs index d5b4e964f..ffa770699 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,12 +30,16 @@ //! /// Error type +// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))] #[derive(thiserror::Error, Debug)] #[allow(missing_docs)] pub enum EZKLError { #[error("[aggregation] {0}")] AggregationError(#[from] pfsys::evm::aggregation_kzg::AggregationError), - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[eth] {0}")] EthError(#[from] eth::EthError), #[error("[graph] {0}")] @@ -54,7 +58,10 @@ pub enum EZKLError { JsonError(#[from] serde_json::Error), #[error("[utf8] {0}")] Utf8Error(#[from] std::str::Utf8Error), - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[reqwest] {0}")] ReqwestError(#[from] reqwest::Error), #[error("[fmt] {0}")] @@ -63,7 +70,10 @@ pub enum EZKLError { Halo2Error(#[from] halo2_proofs::plonk::Error), #[error("[Uncategorized] {0}")] UncategorizedError(String), - #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] + #[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) + ))] #[error("[execute] {0}")] ExecutionError(#[from] execute::ExecutionError), #[error("[srs] {0}")] @@ -85,7 +95,9 @@ impl From for EZKLError { use std::str::FromStr; use circuit::{table::Range, CheckMode, Tolerance}; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use clap::Args; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use fieldutils::IntegerRep; use graph::Visibility; use halo2_proofs::poly::{ @@ -93,52 +105,62 @@ use halo2_proofs::poly::{ }; use halo2curves::bn256::{Bn256, G1Affine}; use serde::{Deserialize, Serialize}; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tosubcommand::ToFlags; +/// Bindings managment +#[cfg(any( + feature = "ios-bindings", + all(target_arch = "wasm32", target_os = "unknown"), + feature = "python-bindings" +))] +pub mod bindings; /// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit. pub mod circuit; /// CLI commands. -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub mod commands; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] // abigen doesn't generate docs for this module #[allow(missing_docs)] /// Utility functions for contracts pub mod eth; /// Command execution /// -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] pub mod execute; /// Utilities for converting from Halo2 Field types to integers (and vice-versa). pub mod fieldutils; /// Methods for loading onnx format models and automatically laying them out in /// a Halo2 circuit. -#[cfg(feature = "onnx")] +#[cfg(any(feature = "onnx", not(feature = "ezkl")))] pub mod graph; /// beautiful logging -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] pub mod logger; /// Tools for proofs and verification used by cli pub mod pfsys; -/// Python bindings -#[cfg(feature = "python-bindings")] -pub mod python; /// srs sha hashes -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +#[cfg(all( + feature = "ezkl", + not(all(target_arch = "wasm32", target_os = "unknown")) +))] pub mod srs_sha; /// An implementation of multi-dimensional tensors. pub mod tensor; -/// wasm prover and verifier -#[cfg(all(target_arch = "wasm32", target_os = "unknown"))] -pub mod wasm; +#[cfg(feature = "ios-bindings")] +uniffi::setup_scaffolding!(); -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use lazy_static::lazy_static; /// The denominator in the fixed point representation used when quantizing inputs pub type Scale = i32; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] // Buf writer capacity lazy_static! { /// The capacity of the buffer used for writing to disk @@ -153,10 +175,10 @@ lazy_static! { } -#[cfg(target_arch = "wasm32")] +#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] const EZKL_KEY_FORMAT: &str = "raw-bytes"; -#[cfg(target_arch = "wasm32")] +#[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] const EZKL_BUF_CAPACITY: &usize = &8000; #[derive( @@ -209,6 +231,7 @@ impl std::fmt::Display for Commitments { } } +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for Commitments { /// Convert the struct to a subcommand string fn to_flags(&self) -> Vec { @@ -231,57 +254,67 @@ impl From for Commitments { } /// Parameters specific to a proving run -#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)] +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, PartialOrd)] +#[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + derive(Args, ToFlags) +)] pub struct RunArgs { /// The tolerance for error on model outputs - #[arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'T', long, default_value = "0", value_hint = clap::ValueHint::Other))] pub tolerance: Tolerance, /// The denominator in the fixed point representation used when quantizing inputs - #[arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'S', long, default_value = "7", value_hint = clap::ValueHint::Other))] pub input_scale: Scale, /// The denominator in the fixed point representation used when quantizing parameters - #[arg(long, default_value = "7", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "7", value_hint = clap::ValueHint::Other))] pub param_scale: Scale, /// if the scale is ever > scale_rebase_multiplier * input_scale then the scale is rebased to input_scale (this a more advanced parameter, use with caution) - #[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "1", value_hint = clap::ValueHint::Other))] pub scale_rebase_multiplier: u32, /// The min and max elements in the lookup table input column - #[arg(short = 'B', long, value_parser = parse_key_val::, default_value = "-32768->32768")] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'B', long, value_parser = parse_key_val::, default_value = "-32768->32768"))] pub lookup_range: Range, /// The log_2 number of rows - #[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other))] pub logrows: u32, /// The log_2 number of rows - #[arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'N', long, default_value = "2", value_hint = clap::ValueHint::Other))] pub num_inner_cols: usize, /// Hand-written parser for graph variables, eg. batch_size=1 - #[arg(short = 'V', long, value_parser = parse_key_val::, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(short = 'V', long, value_parser = parse_key_val::, default_value = "batch_size->1", value_delimiter = ',', value_hint = clap::ValueHint::Other))] pub variables: Vec<(String, usize)>, /// Flags whether inputs are public, private, fixed, hashed, polycommit - #[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))] pub input_visibility: Visibility, /// Flags whether outputs are public, private, fixed, hashed, polycommit - #[arg(long, default_value = "public", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "public", value_hint = clap::ValueHint::Other))] pub output_visibility: Visibility, /// Flags whether params are fixed, private, hashed, polycommit - #[arg(long, default_value = "private", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "private", value_hint = clap::ValueHint::Other))] pub param_visibility: Visibility, - #[arg(long, default_value = "false")] + #[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + arg(long, default_value = "false") + )] /// Rebase the scale using lookup table for division instead of using a range check pub div_rebasing: bool, /// Should constants with 0.0 fraction be rebased to scale 0 - #[arg(long, default_value = "false")] + #[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + arg(long, default_value = "false") + )] pub rebase_frac_zero_constants: bool, /// check mode (safe, unsafe, etc) - #[arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "unsafe", value_hint = clap::ValueHint::Other))] pub check_mode: CheckMode, /// commitment scheme - #[arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "kzg", value_hint = clap::ValueHint::Other))] pub commitment: Option, /// the base used for decompositions - #[arg(long, default_value = "16384", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "16384", value_hint = clap::ValueHint::Other))] pub decomp_base: usize, - #[arg(long, default_value = "2", value_hint = clap::ValueHint::Other)] + #[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long, default_value = "2", value_hint = clap::ValueHint::Other))] /// the number of legs used for decompositions pub decomp_legs: usize, } @@ -354,6 +387,7 @@ impl RunArgs { } /// Parse a single key-value pair +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] fn parse_key_val( s: &str, ) -> Result<(T, U), Box> diff --git a/src/pfsys/evm/aggregation_kzg.rs b/src/pfsys/evm/aggregation_kzg.rs index addd4b390..8728a81f9 100644 --- a/src/pfsys/evm/aggregation_kzg.rs +++ b/src/pfsys/evm/aggregation_kzg.rs @@ -1,7 +1,7 @@ -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use crate::graph::CircuitSize; use crate::pfsys::{Snark, SnarkWitness}; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use colored_json::ToColoredJson; use halo2_proofs::circuit::AssignedCell; use halo2_proofs::plonk::{self}; @@ -20,7 +20,7 @@ use halo2_wrong_ecc::{ use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2curves::ff::PrimeField; use itertools::Itertools; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use log::debug; use log::trace; use rand::rngs::OsRng; @@ -200,7 +200,7 @@ impl AggregationConfig { let range_config = RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); - #[cfg(not(target_arch = "wasm32"))] + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] { let circuit_size = CircuitSize::from_cs(meta, 23); diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index a83c11138..9acee1a0a 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -13,6 +13,7 @@ use crate::circuit::CheckMode; use crate::graph::GraphWitness; use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript; use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT}; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use clap::ValueEnum; use halo2_proofs::circuit::Value; use halo2_proofs::plonk::{ @@ -42,6 +43,7 @@ use std::io::{self, BufReader, BufWriter, Cursor, Write}; use std::ops::Deref; use std::path::PathBuf; use thiserror::Error as thisError; +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] use tosubcommand::ToFlags; use halo2curves::bn256::{Bn256, Fr, G1Affine}; @@ -56,8 +58,10 @@ fn serde_format_from_str(s: &str) -> halo2_proofs::SerdeFormat { } #[allow(missing_docs)] -#[derive( - ValueEnum, Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd, +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] +#[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + derive(ValueEnum) )] pub enum ProofType { #[default] @@ -77,7 +81,7 @@ impl std::fmt::Display for ProofType { ) } } - +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for ProofType { fn to_flags(&self) -> Vec { vec![format!("{}", self)] @@ -129,17 +133,38 @@ impl<'source> pyo3::FromPyObject<'source> for ProofType { } #[allow(missing_docs)] -#[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +#[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + derive(ValueEnum) +)] pub enum StrategyType { Single, Accum, } impl std::fmt::Display for StrategyType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.to_possible_value() - .expect("no values are skipped") - .get_name() - .fmt(f) + // When the `ezkl` feature is disabled or we're targeting `wasm32`, use basic string representation. + #[cfg(any(not(feature = "ezkl"), target_arch = "wasm32"))] + { + write!( + f, + "{}", + match self { + StrategyType::Single => "single", + StrategyType::Accum => "accum", + } + ) + } + + // When the `ezkl` feature is enabled and we're not targeting `wasm32`, use `to_possible_value`. + #[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] + { + self.to_possible_value() + .expect("no values are skipped") + .get_name() + .fmt(f) + } } } #[cfg(feature = "python-bindings")] @@ -177,8 +202,10 @@ pub enum PfSysError { } #[allow(missing_docs)] -#[derive( - ValueEnum, Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd, +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize, PartialOrd)] +#[cfg_attr( + all(feature = "ezkl", not(target_arch = "wasm32")), + derive(ValueEnum) )] pub enum TranscriptType { Poseidon, @@ -198,7 +225,7 @@ impl std::fmt::Display for TranscriptType { ) } } - +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] impl ToFlags for TranscriptType { fn to_flags(&self) -> Vec { vec![format!("{}", self)] @@ -862,7 +889,7 @@ pub fn save_params( //////////////////////// #[cfg(test)] -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] mod tests { use super::*; diff --git a/src/wasm.rs b/src/wasm.rs deleted file mode 100644 index d3cff44f7..000000000 --- a/src/wasm.rs +++ /dev/null @@ -1,793 +0,0 @@ -use crate::{ - circuit::{ - modules::{ - polycommit::PolyCommitChip, - poseidon::{ - spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}, - PoseidonChip, - }, - Module, - }, - region::RegionSettings, - }, - fieldutils::{felt_to_integer_rep, integer_rep_to_felt}, - graph::{ - modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit, - GraphSettings, - }, - pfsys::{ - create_proof_circuit, - evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript}, - verify_proof_circuit, TranscriptType, - }, - tensor::TensorType, - CheckMode, Commitments, -}; -use console_error_panic_hook; -use halo2_proofs::{ - plonk::*, - poly::{ - commitment::{CommitmentScheme, ParamsProver}, - ipa::{ - commitment::{IPACommitmentScheme, ParamsIPA}, - multiopen::{ProverIPA, VerifierIPA}, - strategy::SingleStrategy as IPASingleStrategy, - }, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy as KZGSingleStrategy, - }, - VerificationStrategy, - }, -}; -use halo2_solidity_verifier::encode_calldata; -use halo2curves::{ - bn256::{Bn256, Fr, G1Affine}, - ff::{FromUniformBytes, PrimeField}, -}; -use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript}; -use std::str::FromStr; -use wasm_bindgen::prelude::*; -use wasm_bindgen_console_logger::DEFAULT_LOGGER; - -#[cfg(feature = "web")] -pub use wasm_bindgen_rayon::init_thread_pool; - -#[wasm_bindgen] -/// Initialize logger for wasm -pub fn init_logger() { - log::set_logger(&DEFAULT_LOGGER).unwrap(); -} - -#[wasm_bindgen] -/// Initialize panic hook for wasm -pub fn init_panic_hook() { - console_error_panic_hook::set_once(); -} - -/// Wrapper around the halo2 encode call data method -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn encodeVerifierCalldata( - proof: wasm_bindgen::Clamped>, - vk_address: Option>, -) -> Result, JsError> { - let snark: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; - - let vk_address: Option<[u8; 20]> = if let Some(vk_address) = vk_address { - let array: [u8; 20] = serde_json::from_slice(&vk_address[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize vk address: {}", e)))?; - Some(array) - } else { - None - }; - - let flattened_instances = snark.instances.into_iter().flatten(); - - let encoded = encode_calldata( - vk_address, - &snark.proof, - &flattened_instances.collect::>(), - ); - - Ok(encoded) -} - -/// Converts a hex string to a byte array -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn feltToBigEndian(array: wasm_bindgen::Clamped>) -> Result { - let felt: Fr = serde_json::from_slice(&array[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; - Ok(format!("{:?}", felt)) -} - -/// Converts a felt to a little endian string -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn feltToLittleEndian(array: wasm_bindgen::Clamped>) -> Result { - let felt: Fr = serde_json::from_slice(&array[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; - let repr = serde_json::to_string(&felt).unwrap(); - let b: String = serde_json::from_str(&repr).unwrap(); - Ok(b) -} - -/// Converts a hex string to a byte array -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn feltToInt( - array: wasm_bindgen::Clamped>, -) -> Result>, JsError> { - let felt: Fr = serde_json::from_slice(&array[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; - Ok(wasm_bindgen::Clamped( - serde_json::to_vec(&felt_to_integer_rep(felt)) - .map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?, - )) -} - -/// Converts felts to a floating point element -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn feltToFloat( - array: wasm_bindgen::Clamped>, - scale: crate::Scale, -) -> Result { - let felt: Fr = serde_json::from_slice(&array[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?; - let int_rep = felt_to_integer_rep(felt); - let multiplier = scale_to_multiplier(scale); - Ok(int_rep as f64 / multiplier) -} - -/// Converts a floating point number to a hex string representing a fixed point field element -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn floatToFelt( - input: f64, - scale: crate::Scale, -) -> Result>, JsError> { - let int_rep = - quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?; - let felt = integer_rep_to_felt(int_rep); - let vec = crate::pfsys::field_to_string::(&felt); - Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err( - |e| JsError::new(&format!("Failed to serialize a float to felt{}", e)), - )?)) -} - -/// Generate a kzg commitment. -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn kzgCommit( - message: wasm_bindgen::Clamped>, - vk: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, - params_ser: wasm_bindgen::Clamped>, -) -> Result>, JsError> { - let message: Vec = serde_json::from_slice(&message[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; - - let mut reader = std::io::BufReader::new(¶ms_ser[..]); - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - - let mut reader = std::io::BufReader::new(&vk[..]); - let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; - let vk = VerifyingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit_settings, - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; - - let output = PolyCommitChip::commit::>( - message, - (vk.cs().blinding_factors() + 1) as u32, - ¶ms, - ); - - Ok(wasm_bindgen::Clamped( - serde_json::to_vec(&output).map_err(|e| JsError::new(&format!("{}", e)))?, - )) -} - -/// Converts a buffer to vector of 4 u64s representing a fixed point field element -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn bufferToVecOfFelt( - buffer: wasm_bindgen::Clamped>, -) -> Result>, JsError> { - // Convert the buffer to a slice - let buffer: &[u8] = &buffer; - - // Divide the buffer into chunks of 64 bytes - let chunks = buffer.chunks_exact(16); - - // Get the remainder - let remainder = chunks.remainder(); - - // Add 0s to the remainder to make it 64 bytes - let mut remainder = remainder.to_vec(); - - // Collect chunks into a Vec<[u8; 16]>. - let chunks: Result, JsError> = chunks - .map(|slice| { - let array: [u8; 16] = slice - .try_into() - .map_err(|_| JsError::new("failed to slice input chunks"))?; - Ok(array) - }) - .collect(); - - let mut chunks = chunks?; - - if remainder.len() != 0 { - remainder.resize(16, 0); - // Convert the Vec to [u8; 16] - let remainder_array: [u8; 16] = remainder - .try_into() - .map_err(|_| JsError::new("failed to slice remainder"))?; - // append the remainder to the chunks - chunks.push(remainder_array); - } - - // Convert each chunk to a field element - let field_elements: Vec = chunks - .iter() - .map(|x| PrimeField::from_u128(u8_array_to_u128_le(*x))) - .collect(); - - Ok(wasm_bindgen::Clamped( - serde_json::to_vec(&field_elements) - .map_err(|e| JsError::new(&format!("Failed to serialize field elements: {}", e)))?, - )) -} - -/// Generate a poseidon hash in browser. Input message -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn poseidonHash( - message: wasm_bindgen::Clamped>, -) -> Result>, JsError> { - let message: Vec = serde_json::from_slice(&message[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize message: {}", e)))?; - - let output = - PoseidonChip::::run( - message.clone(), - ) - .map_err(|e| JsError::new(&format!("{}", e)))?; - - Ok(wasm_bindgen::Clamped(serde_json::to_vec(&output).map_err( - |e| JsError::new(&format!("Failed to serialize poseidon hash output: {}", e)), - )?)) -} - -/// Generate a witness file from input.json, compiled model and a settings.json file. -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn genWitness( - compiled_circuit: wasm_bindgen::Clamped>, - input: wasm_bindgen::Clamped>, -) -> Result, JsError> { - let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; - let input: crate::graph::input::GraphData = serde_json::from_slice(&input[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize input: {}", e)))?; - - let mut input = circuit - .load_graph_input(&input) - .map_err(|e| JsError::new(&format!("{}", e)))?; - - let witness = circuit - .forward::>( - &mut input, - None, - None, - RegionSettings::all_true( - circuit.settings().run_args.decomp_base, - circuit.settings().run_args.decomp_legs, - ), - ) - .map_err(|e| JsError::new(&format!("{}", e)))?; - - serde_json::to_vec(&witness) - .map_err(|e| JsError::new(&format!("Failed to serialize witness: {}", e))) -} - -/// Generate verifying key in browser -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn genVk( - compiled_circuit: wasm_bindgen::Clamped>, - params_ser: wasm_bindgen::Clamped>, - compress_selectors: bool, -) -> Result, JsError> { - // Read in kzg params - let mut reader = std::io::BufReader::new(¶ms_ser[..]); - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - // Read in compiled circuit - let circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; - - // Create verifying key - let vk = create_vk_wasm::, Fr, GraphCircuit>( - &circuit, - ¶ms, - compress_selectors, - ) - .map_err(Box::::from) - .map_err(|e| JsError::new(&format!("Failed to create verifying key: {}", e)))?; - - let mut serialized_vk = Vec::new(); - vk.write(&mut serialized_vk, halo2_proofs::SerdeFormat::RawBytes) - .map_err(|e| JsError::new(&format!("Failed to serialize vk: {}", e)))?; - - Ok(serialized_vk) -} - -/// Generate proving key in browser -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn genPk( - vk: wasm_bindgen::Clamped>, - compiled_circuit: wasm_bindgen::Clamped>, - params_ser: wasm_bindgen::Clamped>, -) -> Result, JsError> { - // Read in kzg params - let mut reader = std::io::BufReader::new(¶ms_ser[..]); - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - // Read in compiled circuit - let circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize compiled model: {}", e)))?; - - // Read in verifying key - let mut reader = std::io::BufReader::new(&vk[..]); - let vk = VerifyingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit.settings().clone(), - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize verifying key: {}", e)))?; - // Create proving key - let pk = create_pk_wasm::, Fr, GraphCircuit>(vk, &circuit, ¶ms) - .map_err(Box::::from) - .map_err(|e| JsError::new(&format!("Failed to create proving key: {}", e)))?; - - let mut serialized_pk = Vec::new(); - pk.write(&mut serialized_pk, halo2_proofs::SerdeFormat::RawBytes) - .map_err(|e| JsError::new(&format!("Failed to serialize pk: {}", e)))?; - - Ok(serialized_pk) -} - -/// Verify proof in browser using wasm -#[wasm_bindgen] -pub fn verify( - proof_js: wasm_bindgen::Clamped>, - vk: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, - srs: wasm_bindgen::Clamped>, -) -> Result { - let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; - - let proof: crate::pfsys::Snark = serde_json::from_slice(&proof_js[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; - - let mut reader = std::io::BufReader::new(&vk[..]); - let vk = VerifyingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit_settings.clone(), - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; - - let orig_n = 1 << circuit_settings.run_args.logrows; - - let commitment = circuit_settings.run_args.commitment.into(); - - let mut reader = std::io::BufReader::new(&srs[..]); - let result = match commitment { - Commitments::KZG => { - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - let strategy = KZGSingleStrategy::new(params.verifier_params()); - match proof.transcript_type { - TranscriptType::EVM => verify_proof_circuit::< - VerifierSHPLONK<'_, Bn256>, - KZGCommitmentScheme, - KZGSingleStrategy<_>, - _, - EvmTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n), - - TranscriptType::Poseidon => { - verify_proof_circuit::< - VerifierSHPLONK<'_, Bn256>, - KZGCommitmentScheme, - KZGSingleStrategy<_>, - _, - PoseidonTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n) - } - } - } - Commitments::IPA => { - let params: ParamsIPA<_> = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - let strategy = IPASingleStrategy::new(params.verifier_params()); - match proof.transcript_type { - TranscriptType::EVM => verify_proof_circuit::< - VerifierIPA<_>, - IPACommitmentScheme, - IPASingleStrategy<_>, - _, - EvmTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n), - TranscriptType::Poseidon => { - verify_proof_circuit::< - VerifierIPA<_>, - IPACommitmentScheme, - IPASingleStrategy<_>, - _, - PoseidonTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n) - } - } - } - }; - - match result { - Ok(_) => Ok(true), - Err(e) => Err(JsError::new(&format!("{}", e))), - } -} - -#[wasm_bindgen] -#[allow(non_snake_case)] -/// Verify aggregate proof in browser using wasm -pub fn verifyAggr( - proof_js: wasm_bindgen::Clamped>, - vk: wasm_bindgen::Clamped>, - logrows: u64, - srs: wasm_bindgen::Clamped>, - commitment: &str, -) -> Result { - let proof: crate::pfsys::Snark = serde_json::from_slice(&proof_js[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; - - let mut reader = std::io::BufReader::new(&vk[..]); - let vk = VerifyingKey::::read::<_, AggregationCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - (), - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; - - let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?; - - let orig_n = 1 << logrows; - - let mut reader = std::io::BufReader::new(&srs[..]); - let result = match commit { - Commitments::KZG => { - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - let strategy = KZGSingleStrategy::new(params.verifier_params()); - match proof.transcript_type { - TranscriptType::EVM => verify_proof_circuit::< - VerifierSHPLONK<'_, Bn256>, - KZGCommitmentScheme, - KZGSingleStrategy<_>, - _, - EvmTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n), - - TranscriptType::Poseidon => { - verify_proof_circuit::< - VerifierSHPLONK<'_, Bn256>, - KZGCommitmentScheme, - KZGSingleStrategy<_>, - _, - PoseidonTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n) - } - } - } - Commitments::IPA => { - let params: ParamsIPA<_> = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - let strategy = IPASingleStrategy::new(params.verifier_params()); - match proof.transcript_type { - TranscriptType::EVM => verify_proof_circuit::< - VerifierIPA<_>, - IPACommitmentScheme, - IPASingleStrategy<_>, - _, - EvmTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n), - TranscriptType::Poseidon => { - verify_proof_circuit::< - VerifierIPA<_>, - IPACommitmentScheme, - IPASingleStrategy<_>, - _, - PoseidonTranscript, - >(&proof, ¶ms, &vk, strategy, orig_n) - } - } - } - }; - - match result { - Ok(_) => Ok(true), - Err(e) => Err(JsError::new(&format!("{}", e))), - } -} - -/// Prove in browser using wasm -#[wasm_bindgen] -pub fn prove( - witness: wasm_bindgen::Clamped>, - pk: wasm_bindgen::Clamped>, - compiled_circuit: wasm_bindgen::Clamped>, - srs: wasm_bindgen::Clamped>, -) -> Result, JsError> { - #[cfg(feature = "det-prove")] - log::set_max_level(log::LevelFilter::Debug); - #[cfg(not(feature = "det-prove"))] - log::set_max_level(log::LevelFilter::Info); - - // read in circuit - let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize circuit: {}", e)))?; - - // read in model input - let data: crate::graph::GraphWitness = serde_json::from_slice(&witness[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize witness: {}", e)))?; - - // read in proving key - let mut reader = std::io::BufReader::new(&pk[..]); - let pk = ProvingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit.settings().clone(), - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize proving key: {}", e)))?; - - // prep public inputs - circuit - .load_graph_witness(&data) - .map_err(|e| JsError::new(&format!("{}", e)))?; - let public_inputs = circuit - .prepare_public_inputs(&data) - .map_err(|e| JsError::new(&format!("{}", e)))?; - let proof_split_commits: Option = data.into(); - - // read in kzg params - let mut reader = std::io::BufReader::new(&srs[..]); - let commitment = circuit.settings().run_args.commitment.into(); - // creates and verifies the proof - let proof = match commitment { - Commitments::KZG => { - let params: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?; - - create_proof_circuit::< - KZGCommitmentScheme, - _, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - KZGSingleStrategy<_>, - _, - EvmTranscript<_, _, _, _>, - EvmTranscript<_, _, _, _>, - >( - circuit, - vec![public_inputs], - ¶ms, - &pk, - CheckMode::UNSAFE, - crate::Commitments::KZG, - TranscriptType::EVM, - proof_split_commits, - None, - ) - } - Commitments::IPA => { - let params: ParamsIPA<_> = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?; - - create_proof_circuit::< - IPACommitmentScheme, - _, - ProverIPA<_>, - VerifierIPA<_>, - IPASingleStrategy<_>, - _, - EvmTranscript<_, _, _, _>, - EvmTranscript<_, _, _, _>, - >( - circuit, - vec![public_inputs], - ¶ms, - &pk, - CheckMode::UNSAFE, - crate::Commitments::IPA, - TranscriptType::EVM, - proof_split_commits, - None, - ) - } - } - .map_err(|e| JsError::new(&format!("{}", e)))?; - - Ok(serde_json::to_string(&proof) - .map_err(|e| JsError::new(&format!("{}", e)))? - .into_bytes()) -} - -// VALIDATION FUNCTIONS - -/// Witness file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn witnessValidation(witness: wasm_bindgen::Clamped>) -> Result { - let _: crate::graph::GraphWitness = serde_json::from_slice(&witness[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize witness: {}", e)))?; - - Ok(true) -} -/// Compiled circuit validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn compiledCircuitValidation( - compiled_circuit: wasm_bindgen::Clamped>, -) -> Result { - let _: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize compiled circuit: {}", e)))?; - - Ok(true) -} -/// Input file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn inputValidation(input: wasm_bindgen::Clamped>) -> Result { - let _: crate::graph::input::GraphData = serde_json::from_slice(&input[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize input: {}", e)))?; - - Ok(true) -} -/// Proof file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn proofValidation(proof: wasm_bindgen::Clamped>) -> Result { - let _: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; - - Ok(true) -} -/// Vk file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn vkValidation( - vk: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, -) -> Result { - let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; - let mut reader = std::io::BufReader::new(&vk[..]); - let _ = VerifyingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit_settings, - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; - - Ok(true) -} -/// Pk file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn pkValidation( - pk: wasm_bindgen::Clamped>, - settings: wasm_bindgen::Clamped>, -) -> Result { - let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; - let mut reader = std::io::BufReader::new(&pk[..]); - let _ = ProvingKey::::read::<_, GraphCircuit>( - &mut reader, - halo2_proofs::SerdeFormat::RawBytes, - circuit_settings, - ) - .map_err(|e| JsError::new(&format!("Failed to deserialize proving key: {}", e)))?; - - Ok(true) -} -/// Settings file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn settingsValidation(settings: wasm_bindgen::Clamped>) -> Result { - let _: GraphSettings = serde_json::from_slice(&settings[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?; - - Ok(true) -} -/// Srs file validation -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn srsValidation(srs: wasm_bindgen::Clamped>) -> Result { - let mut reader = std::io::BufReader::new(&srs[..]); - let _: ParamsKZG = - halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) - .map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?; - - Ok(true) -} - -// HELPER FUNCTIONS - -/// Creates a [ProvingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target -#[cfg(target_arch = "wasm32")] -pub fn create_vk_wasm>( - circuit: &C, - params: &'_ Scheme::ParamsProver, - compress_selectors: bool, -) -> Result, halo2_proofs::plonk::Error> -where - C: Circuit, - ::Scalar: FromUniformBytes<64>, -{ - // Real proof - let empty_circuit = >::without_witnesses(circuit); - - // Initialize the verifying key - let vk = keygen_vk_custom(params, &empty_circuit, compress_selectors)?; - Ok(vk) -} -/// Creates a [ProvingKey] from a [VerifyingKey] for a [GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`) for the WASM target -#[cfg(target_arch = "wasm32")] -pub fn create_pk_wasm>( - vk: VerifyingKey, - circuit: &C, - params: &'_ Scheme::ParamsProver, -) -> Result, halo2_proofs::plonk::Error> -where - C: Circuit, - ::Scalar: FromUniformBytes<64>, -{ - // Real proof - let empty_circuit = >::without_witnesses(circuit); - - // Initialize the proving key - let pk = keygen_pk(params, vk, &empty_circuit)?; - Ok(pk) -} - -/// -pub fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 { - let mut n: u128 = 0; - for &b in arr.iter().rev() { - n <<= 8; - n |= b as u128; - } - n -} diff --git a/tests/wasm/calibration.json b/tests/assets/calibration.json similarity index 100% rename from tests/wasm/calibration.json rename to tests/assets/calibration.json diff --git a/tests/wasm/input.json b/tests/assets/input.json similarity index 100% rename from tests/wasm/input.json rename to tests/assets/input.json diff --git a/tests/wasm/kzg b/tests/assets/kzg similarity index 100% rename from tests/wasm/kzg rename to tests/assets/kzg diff --git a/tests/wasm/kzg1.srs b/tests/assets/kzg1.srs similarity index 100% rename from tests/wasm/kzg1.srs rename to tests/assets/kzg1.srs diff --git a/tests/wasm/model.compiled b/tests/assets/model.compiled similarity index 100% rename from tests/wasm/model.compiled rename to tests/assets/model.compiled diff --git a/tests/wasm/network.onnx b/tests/assets/network.onnx similarity index 100% rename from tests/wasm/network.onnx rename to tests/assets/network.onnx diff --git a/tests/wasm/pk.key b/tests/assets/pk.key similarity index 100% rename from tests/wasm/pk.key rename to tests/assets/pk.key diff --git a/tests/wasm/proof.json b/tests/assets/proof.json similarity index 100% rename from tests/wasm/proof.json rename to tests/assets/proof.json diff --git a/tests/wasm/proof_aggr.json b/tests/assets/proof_aggr.json similarity index 100% rename from tests/wasm/proof_aggr.json rename to tests/assets/proof_aggr.json diff --git a/tests/wasm/settings.json b/tests/assets/settings.json similarity index 100% rename from tests/wasm/settings.json rename to tests/assets/settings.json diff --git a/tests/wasm/vk.key b/tests/assets/vk.key similarity index 100% rename from tests/wasm/vk.key rename to tests/assets/vk.key diff --git a/tests/wasm/vk_aggr.key b/tests/assets/vk_aggr.key similarity index 100% rename from tests/wasm/vk_aggr.key rename to tests/assets/vk_aggr.key diff --git a/tests/wasm/witness.json b/tests/assets/witness.json similarity index 100% rename from tests/wasm/witness.json rename to tests/assets/witness.json diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 30eb731ac..ff7752dbe 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1,4 +1,4 @@ -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] #[cfg(test)] mod native_tests { diff --git a/tests/ios/can_verify_aggr.swift b/tests/ios/can_verify_aggr.swift new file mode 100644 index 000000000..b5986058c --- /dev/null +++ b/tests/ios/can_verify_aggr.swift @@ -0,0 +1,39 @@ +// Write a simple swift test +import ezkl +import Foundation + +let pathToFile = "../../../../tests/assets/" + + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let proofAggrPath = pathToFile + "proof_aggr.json" + let vkAggrPath = pathToFile + "vk_aggr.key" + let srs1Path = pathToFile + "kzg1.srs" + + guard let proofAggr = loadFileAsBytes(from: proofAggrPath) else { + fatalError("Failed to load proofAggr file") + } + guard let vkAggr = loadFileAsBytes(from: vkAggrPath) else { + fatalError("Failed to load vkAggr file") + } + guard let srs1 = loadFileAsBytes(from: srs1Path) else { + fatalError("Failed to load srs1 file") + } + + let value = try verifyAggr( + proof: proofAggr, + vk: vkAggr, + logrows: 21, + srs: srs1, + commitment: "kzg" + ) + + // should not fail + assert(value == true, "Failed the test") + +} \ No newline at end of file diff --git a/tests/ios/gen_pk_test.swift b/tests/ios/gen_pk_test.swift new file mode 100644 index 000000000..d51173fec --- /dev/null +++ b/tests/ios/gen_pk_test.swift @@ -0,0 +1,42 @@ +// Swift version of gen_pk_test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let networkCompiledPath = pathToFile + "model.compiled" + let srsPath = pathToFile + "kzg" + + // Load necessary files + guard let compiledCircuit = loadFileAsBytes(from: networkCompiledPath) else { + fatalError("Failed to load network compiled file") + } + guard let srs = loadFileAsBytes(from: srsPath) else { + fatalError("Failed to load SRS file") + } + + // Generate the vk (Verifying Key) + let vk = try genVk( + compiledCircuit: compiledCircuit, + srs: srs, + compressSelectors: true // Corresponds to the `true` boolean in the Rust code + ) + + // Generate the pk (Proving Key) + let pk = try genPk( + vk: vk, + compiledCircuit: compiledCircuit, + srs: srs + ) + + // Ensure that the proving key is not empty + assert(pk.count > 0, "Proving key generation failed, pk is empty") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/gen_vk_test.swift b/tests/ios/gen_vk_test.swift new file mode 100644 index 000000000..51b10be8a --- /dev/null +++ b/tests/ios/gen_vk_test.swift @@ -0,0 +1,35 @@ +// Swift version of gen_vk_test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let networkCompiledPath = pathToFile + "model.compiled" + let srsPath = pathToFile + "kzg" + + // Load necessary files + guard let compiledCircuit = loadFileAsBytes(from: networkCompiledPath) else { + fatalError("Failed to load network compiled file") + } + guard let srs = loadFileAsBytes(from: srsPath) else { + fatalError("Failed to load SRS file") + } + + // Generate the vk (Verifying Key) + let vk = try genVk( + compiledCircuit: compiledCircuit, + srs: srs, + compressSelectors: true // Corresponds to the `true` boolean in the Rust code + ) + + // Ensure that the verifying key is not empty + assert(vk.count > 0, "Verifying key generation failed, vk is empty") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/pk_is_valid_test.swift b/tests/ios/pk_is_valid_test.swift new file mode 100644 index 000000000..deffdc18a --- /dev/null +++ b/tests/ios/pk_is_valid_test.swift @@ -0,0 +1,69 @@ +// Swift version of pk_is_valid_test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let networkCompiledPath = pathToFile + "model.compiled" + let srsPath = pathToFile + "kzg" + let witnessPath = pathToFile + "witness.json" + let settingsPath = pathToFile + "settings.json" + + // Load necessary files + guard let compiledCircuit = loadFileAsBytes(from: networkCompiledPath) else { + fatalError("Failed to load network compiled file") + } + guard let srs = loadFileAsBytes(from: srsPath) else { + fatalError("Failed to load SRS file") + } + guard let witness = loadFileAsBytes(from: witnessPath) else { + fatalError("Failed to load witness file") + } + guard let settings = loadFileAsBytes(from: settingsPath) else { + fatalError("Failed to load settings file") + } + + // Generate the vk (Verifying Key) + let vk = try genVk( + compiledCircuit: compiledCircuit, + srs: srs, + compressSelectors: true // Corresponds to the `true` boolean in the Rust code + ) + + // Generate the pk (Proving Key) + let pk = try genPk( + vk: vk, + compiledCircuit: compiledCircuit, + srs: srs + ) + + // Prove using the witness and proving key + let proof = try prove( + witness: witness, + pk: pk, + compiledCircuit: compiledCircuit, + srs: srs + ) + + // Ensure that the proof is not empty + assert(proof.count > 0, "Proof generation failed, proof is empty") + + // Verify the proof + let value = try verify( + proof: proof, + vk: vk, + settings: settings, + srs: srs + ) + + // Ensure that the verification passed + assert(value == true, "Verification failed") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/verify_encode_verifier_calldata.swift b/tests/ios/verify_encode_verifier_calldata.swift new file mode 100644 index 000000000..db9308a81 --- /dev/null +++ b/tests/ios/verify_encode_verifier_calldata.swift @@ -0,0 +1,71 @@ +// Swift version of verify_encode_verifier_calldata test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let proofPath = pathToFile + "proof.json" + + guard let proof = loadFileAsBytes(from: proofPath) else { + fatalError("Failed to load proof file") + } + + // Test without vk address + let calldataNoVk = try encodeVerifierCalldata( + proof: proof, + vkAddress: nil + ) + + // Deserialize the proof data + struct Snark: Decodable { + let proof: Data + let instances: Data + } + + let snark = try JSONDecoder().decode(Snark.self, from: proof) + + let flattenedInstances = snark.instances.flatMap { $0 } + let referenceCalldataNoVk = try encodeCalldata( + vk: nil, + proof: snark.proof, + instances: flattenedInstances + ) + + // Check if the encoded calldata matches the reference + assert(calldataNoVk == referenceCalldataNoVk, "Calldata without vk does not match") + + // Test with vk address + let vkAddressString = "0000000000000000000000000000000000000000" + let vkAddressData = Data(hexString: vkAddressString) + + guard vkAddressData.count == 20 else { + fatalError("Invalid VK address length") + } + + let vkAddressArray = [UInt8](vkAddressData) + + // Serialize vkAddress to match JSON serialization in Rust + let serializedVkAddress = try JSONEncoder().encode(vkAddressArray) + + let calldataWithVk = try encodeVerifierCalldata( + proof: proof, + vk: serializedVkAddress + ) + + let referenceCalldataWithVk = try encodeCalldata( + vk: vkAddressArray, + proof: snark.proof, + instances: flattenedInstances + ) + + // Check if the encoded calldata matches the reference + assert(calldataWithVk == referenceCalldataWithVk, "Calldata with vk does not match") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/verify_gen_witness.swift b/tests/ios/verify_gen_witness.swift new file mode 100644 index 000000000..894a1a7ed --- /dev/null +++ b/tests/ios/verify_gen_witness.swift @@ -0,0 +1,45 @@ +// Swift version of verify_gen_witness test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let networkCompiledPath = pathToFile + "model.compiled" + let inputPath = pathToFile + "input.json" + let witnessPath = pathToFile + "witness.json" + + // Load necessary files + guard let networkCompiled = loadFileAsBytes(from: networkCompiledPath) else { + fatalError("Failed to load network compiled file") + } + guard let input = loadFileAsBytes(from: inputPath) else { + fatalError("Failed to load input file") + } + guard let referenceWitnessData = loadFileAsBytes(from: witnessPath) else { + fatalError("Failed to load witness file") + } + + // Generate witness using genWitness function + let witnessData = try genWitness( + compiledCircuit: networkCompiled, + input: input + ) + + // Deserialize the witness + struct GraphWitness: Decodable, Equatable {} + let witness = try JSONDecoder().decode(GraphWitness.self, from: witnessData) + + // Deserialize the reference witness + let referenceWitness = try JSONDecoder().decode(GraphWitness.self, from: referenceWitnessData) + + // Check if the witness matches the reference witness + assert(witness == referenceWitness, "Witnesses do not match") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/verify_kzg_commit.swift b/tests/ios/verify_kzg_commit.swift new file mode 100644 index 000000000..0d9cc3d54 --- /dev/null +++ b/tests/ios/verify_kzg_commit.swift @@ -0,0 +1,64 @@ +// Swift version of verify_kzg_commit test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let vkPath = pathToFile + "vk.key" + let srsPath = pathToFile + "kzg" + let settingsPath = pathToFile + "settings.json" + + guard let vk = loadFileAsBytes(from: vkPath) else { + fatalError("Failed to load vk file") + } + guard let srs = loadFileAsBytes(from: srsPath) else { + fatalError("Failed to load srs file") + } + guard let settings = loadFileAsBytes(from: settingsPath) else { + fatalError("Failed to load settings file") + } + + // Create a vector of field elements + var message: [UInt64] = [] + for i in 0..<32 { + message.append(UInt64(i)) + } + + // Serialize the message array + let messageData = try JSONEncoder().encode(message) + + // Deserialize settings + struct GraphSettings: Decodable {} + let settingsDecoded = try JSONDecoder().decode(GraphSettings.self, from: settings) + + // Generate commitment + let commitmentData = try kzgCommit( + message: messageData, + vk: vk, + settings: settings, + srs: srs + ) + + // Deserialize the resulting commitment + struct G1Affine: Decodable {} + let commitment = try JSONDecoder().decode([G1Affine].self, from: commitmentData) + + // Reference commitment using params and vk + // For Swift, you'd need to implement or link the corresponding methods like in Rust + let referenceCommitment = try polyCommit( + message: message, + vk: vk, + srs: srs + ) + + // Check if the commitment matches the reference + assert(commitment == referenceCommitment, "Commitments do not match") + +} catch { + fatalError("Test failed with error: \(error)") +} \ No newline at end of file diff --git a/tests/ios/verify_validations.swift b/tests/ios/verify_validations.swift new file mode 100644 index 000000000..a7497755d --- /dev/null +++ b/tests/ios/verify_validations.swift @@ -0,0 +1,103 @@ +// Swift version of verify_validations test +import ezkl +import Foundation + +func loadFileAsBytes(from path: String) -> Data? { + let url = URL(fileURLWithPath: path) + return try? Data(contentsOf: url) +} + +do { + let pathToFile = "../../../../tests/assets/" + let compiledCircuitPath = pathToFile + "model.compiled" + let networkPath = pathToFile + "network.onnx" + let witnessPath = pathToFile + "witness.json" + let inputPath = pathToFile + "input.json" + let proofPath = pathToFile + "proof.json" + let vkPath = pathToFile + "vk.key" + let pkPath = pathToFile + "pk.key" + let settingsPath = pathToFile + "settings.json" + let srsPath = pathToFile + "kzg" + + // Load necessary files + guard let compiledCircuit = loadFileAsBytes(from: compiledCircuitPath) else { + fatalError("Failed to load network compiled file") + } + guard let network = loadFileAsBytes(from: networkPath) else { + fatalError("Failed to load network file") + } + guard let witness = loadFileAsBytes(from: witnessPath) else { + fatalError("Failed to load witness file") + } + guard let input = loadFileAsBytes(from: inputPath) else { + fatalError("Failed to load input file") + } + guard let proof = loadFileAsBytes(from: proofPath) else { + fatalError("Failed to load proof file") + } + guard let vk = loadFileAsBytes(from: vkPath) else { + fatalError("Failed to load vk file") + } + guard let pk = loadFileAsBytes(from: pkPath) else { + fatalError("Failed to load pk file") + } + guard let settings = loadFileAsBytes(from: settingsPath) else { + fatalError("Failed to load settings file") + } + guard let srs = loadFileAsBytes(from: srsPath) else { + fatalError("Failed to load srs file") + } + + // Witness validation (should fail for network compiled) + let witnessValidationResult1 = try? witnessValidation(witness:compiledCircuit) + assert(witnessValidationResult1 == nil, "Witness validation should fail for network compiled") + + // Witness validation (should pass for witness) + let witnessValidationResult2 = try? witnessValidation(witness:witness) + assert(witnessValidationResult2 != nil, "Witness validation should pass for witness") + + // Compiled circuit validation (should fail for onnx network) + let circuitValidationResult1 = try? compiledCircuitValidation(compiledCircuit:network) + assert(circuitValidationResult1 == nil, "Compiled circuit validation should fail for onnx network") + + // Compiled circuit validation (should pass for compiled network) + let circuitValidationResult2 = try? compiledCircuitValidation(compiledCircuit:compiledCircuit) + assert(circuitValidationResult2 != nil, "Compiled circuit validation should pass for compiled network") + + // Input validation (should fail for witness) + let inputValidationResult1 = try? inputValidation(input:witness) + assert(inputValidationResult1 == nil, "Input validation should fail for witness") + + // Input validation (should pass for input) + let inputValidationResult2 = try? inputValidation(input:input) + assert(inputValidationResult2 != nil, "Input validation should pass for input") + + // Proof validation (should fail for witness) + let proofValidationResult1 = try? proofValidation(proof:witness) + assert(proofValidationResult1 == nil, "Proof validation should fail for witness") + + // Proof validation (should pass for proof) + let proofValidationResult2 = try? proofValidation(proof:proof) + assert(proofValidationResult2 != nil, "Proof validation should pass for proof") + + // Verifying key (vk) validation (should pass) + let vkValidationResult = try? vkValidation(vk:vk, settings:settings) + assert(vkValidationResult != nil, "VK validation should pass for vk") + + // Proving key (pk) validation (should pass) + let pkValidationResult = try? pkValidation(pk:pk, settings:settings) + assert(pkValidationResult != nil, "PK validation should pass for pk") + + // Settings validation (should fail for proof) + let settingsValidationResult1 = try? settingsValidation(settings:proof) + assert(settingsValidationResult1 == nil, "Settings validation should fail for proof") + + // Settings validation (should pass for settings) + let settingsValidationResult2 = try? settingsValidation(settings:settings) + assert(settingsValidationResult2 != nil, "Settings validation should pass for settings") + + // SRS validation (should pass) + let srsValidationResult = try? srsValidation(srs:srs) + assert(srsValidationResult != nil, "SRS validation should pass for srs") + +} \ No newline at end of file diff --git a/tests/ios_integration_tests.rs b/tests/ios_integration_tests.rs new file mode 100644 index 000000000..756851fbf --- /dev/null +++ b/tests/ios_integration_tests.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "ios-bindings-test")] +uniffi::build_foreign_language_testcases!( + "tests/ios/can_verify_aggr.swift", + "tests/ios/verify_gen_witness.swift", + "tests/ios/gen_pk_test.swift", + "tests/ios/gen_vk_test.swift", + "tests/ios/pk_is_valid_test.swift", + "tests/ios/verify_validations.swift", + // "tests/ios/verify_encode_verifier_calldata.swift", // TODO - the function requires rust dependencies to test + // "tests/ios/verify_kzg_commit.swift", // TODO - the function is not exported and requires rust dependencies to test +); diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index 555e63600..3d3dcb2e6 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -1,4 +1,4 @@ -#[cfg(not(target_arch = "wasm32"))] +#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))] #[cfg(test)] mod py_tests { diff --git a/tests/wasm.rs b/tests/wasm.rs index 3c1f9c1d8..f08a5ae06 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -1,6 +1,12 @@ #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] #[cfg(test)] mod wasm32 { + use ezkl::bindings::wasm::{ + bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian, + feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation, + kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation, + srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation, + }; use ezkl::circuit::modules::polycommit::PolyCommitChip; use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH}; use ezkl::circuit::modules::poseidon::PoseidonChip; @@ -9,12 +15,6 @@ mod wasm32 { use ezkl::graph::GraphCircuit; use ezkl::graph::{GraphSettings, GraphWitness}; use ezkl::pfsys; - use ezkl::wasm::{ - bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian, - feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation, - kzgCommit, pkValidation, poseidonHash, proofValidation, prove, settingsValidation, - srsValidation, u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation, - }; use halo2_proofs::plonk::VerifyingKey; use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; use halo2_proofs::poly::kzg::commitment::ParamsKZG; @@ -28,18 +28,18 @@ mod wasm32 { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - pub const WITNESS: &[u8] = include_bytes!("../tests/wasm/witness.json"); - pub const NETWORK_COMPILED: &[u8] = include_bytes!("../tests/wasm/model.compiled"); - pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx"); - pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json"); - pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json"); - pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json"); - pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json"); - pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key"); - pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key"); - pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key"); - pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg"); - pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs"); + pub const WITNESS: &[u8] = include_bytes!("assets/witness.json"); + pub const NETWORK_COMPILED: &[u8] = include_bytes!("assets/model.compiled"); + pub const NETWORK: &[u8] = include_bytes!("assets/network.onnx"); + pub const INPUT: &[u8] = include_bytes!("assets/input.json"); + pub const PROOF: &[u8] = include_bytes!("assets/proof.json"); + pub const PROOF_AGGR: &[u8] = include_bytes!("assets/proof_aggr.json"); + pub const SETTINGS: &[u8] = include_bytes!("assets/settings.json"); + pub const PK: &[u8] = include_bytes!("assets/pk.key"); + pub const VK: &[u8] = include_bytes!("assets/vk.key"); + pub const VK_AGGR: &[u8] = include_bytes!("assets/vk_aggr.key"); + pub const SRS: &[u8] = include_bytes!("assets/kzg"); + pub const SRS1: &[u8] = include_bytes!("assets/kzg1.srs"); #[wasm_bindgen_test] async fn can_verify_aggr() {