Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ICICLE msm integration #498

Merged
merged 27 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
02db3e7
nixShell setup for icicle
alxiong Feb 16, 2024
32835ee
add basic msm benchmark
alxiong Feb 27, 2024
ef61c72
add Kzg::commit_with_gpu() and unit tests
alxiong Feb 27, 2024
10f2d30
update benchmark
alxiong Feb 27, 2024
250e6f4
add par_iter and fix bench err
alxiong Feb 28, 2024
b1f3c0c
finer-grain print trace, try to fix ci
alxiong Feb 28, 2024
514d479
split up commit_using_gpu() into 4 api
alxiong Mar 1, 2024
7096bad
use PCSError instead of unwrap
alxiong Mar 1, 2024
7ce5363
ci: explict all-features and avoid icicle feature
alxiong Mar 1, 2024
819eaf6
adjust benchmark name
alxiong Mar 1, 2024
33b4d84
Merge remote-tracking branch 'origin/main' into icicle-msm
alxiong Mar 1, 2024
b9c9ac9
hide gpu bench behind feature flag
alxiong Mar 4, 2024
4405329
introduce trait GPUCommit with less mem copy during conversion
alxiong Mar 6, 2024
c43ec12
specialize conversion and mont-based msm for bn254
alxiong Mar 6, 2024
964c63a
warmup for more accurate benchmark
alxiong Mar 6, 2024
4b98d1f
wip: add subslice into srs ref
alxiong Mar 6, 2024
bd5dd36
use mem::forget() to avoid double-free panic
alxiong Mar 7, 2024
ca5b494
improve test
alxiong Mar 7, 2024
4674eee
Merge remote-tracking branch 'origin/main' into icicle-msm
alxiong Mar 7, 2024
d150d66
update bench code
alxiong Mar 7, 2024
34bc815
test: high degree behind print-trace feature
alxiong Mar 7, 2024
24083ba
feat: add batch_commit for gpu
alxiong Mar 7, 2024
49663ab
update CHANGELOG
alxiong Mar 7, 2024
ae6a9ca
fix bench, add 2 handy apis, rename to gpu_commit()
alxiong Mar 8, 2024
9fdb299
nit: minor edit on profile test
alxiong Mar 8, 2024
7e5862c
improve scalar type conv by 2x for batch_commit, directly use par_iter
alxiong Mar 8, 2024
b371f65
fix trimmed_size logic, comparable test
alxiong Mar 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
- name: Check all tests and binaries compilation
run: |
cargo check --workspace --tests --lib --bins
cargo check --workspace --all-features
cargo check --workspace --features 'std parallel test-srs test-apis'

- name: Check no_std support and WASM compilation
env:
Expand Down
125 changes: 79 additions & 46 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
outputs = { self, nixpkgs, flake-utils, rust-overlay, pre-commit-hooks, ... }:
flake-utils.lib.eachDefaultSystem (system:
let
overlays = [
(import rust-overlay)
];
overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs { inherit system overlays; };
nightlyToolchain = pkgs.rust-bin.selectLatestNightlyWith
(toolchain: toolchain.minimal.override { extensions = [ "rustfmt" ]; });
pkgsAllowUnfree = import nixpkgs {
inherit system;
config.allowUnfree = true;
};
gcc11 = pkgs.overrideCC pkgs.stdenv pkgs.gcc11;
nightlyToolchain = pkgs.rust-bin.selectLatestNightlyWith (toolchain:
toolchain.minimal.override { extensions = [ "rustfmt" ]; });

stableToolchain = pkgs.rust-bin.stable.latest.minimal.override {
extensions = [ "clippy" "llvm-tools-preview" "rust-src" ];
Expand All @@ -42,8 +45,52 @@
fi
exec ${stableToolchain}/bin/cargo "$@"
'';
in with pkgs;
{
baseShell = with pkgs;
clang15Stdenv.mkDerivation {
name = "clang15-nix-shell";
buildInputs = [
argbash
openssl
pkg-config
git
nixpkgs-fmt

cargo-with-nightly
stableToolchain
nightlyToolchain
cargo-sort
clang-tools_15
clangStdenv
llvm_15
] ++ lib.optionals stdenv.isDarwin
[ darwin.apple_sdk.frameworks.Security ];

CARGO_TARGET_DIR = "target/nix_rustc";

shellHook = ''
export RUST_BACKTRACE=full
export PATH="$PATH:$(pwd)/target/debug:$(pwd)/target/release"
# Prevent cargo aliases from using programs in `~/.cargo` to avoid conflicts with local rustup installations.
export CARGO_HOME=$HOME/.cargo-nix

# Ensure `cargo fmt` uses `rustfmt` from nightly.
export RUSTFMT="${nightlyToolchain}/bin/rustfmt"

export C_INCLUDE_PATH="${llvmPackages_15.libclang.lib}/lib/clang/${llvmPackages_15.libclang.version}/include"
export LIBCLANG_PATH=
export CC="${clang-tools_15.clang}/bin/clang"
export CXX="${clang-tools_15.clang}/bin/clang++"
export AR="${llvm_15}/bin/llvm-ar"
export CFLAGS="-mcpu=generic"

# by default choose u64_backend
export RUSTFLAGS='--cfg curve25519_dalek_backend="u64"'
''
# install pre-commit hooks
+ self.check.${system}.pre-commit-check.shellHook;
};
in
with pkgs; {
check = {
pre-commit-check = pre-commit-hooks.lib.${system}.run {
src = ./.;
Expand Down Expand Up @@ -72,48 +119,34 @@
entry = "cargo sort -w";
pass_filenames = false;
};
nixpkgs-fmt.enable = true;
};
};
};
devShell = clang15Stdenv.mkDerivation {
name = "clang15-nix-shell";
buildInputs = [
argbash
openssl
pkg-config
git

cargo-with-nightly
stableToolchain
nightlyToolchain
cargo-sort
clang-tools_15
clangStdenv
llvm_15
] ++ lib.optionals stdenv.isDarwin [ darwin.apple_sdk.frameworks.Security ];

CARGO_TARGET_DIR = "target/nix_rustc";

shellHook = ''
export RUST_BACKTRACE=full
export PATH="$PATH:$(pwd)/target/debug:$(pwd)/target/release"
# Prevent cargo aliases from using programs in `~/.cargo` to avoid conflicts with local rustup installations.
export CARGO_HOME=$HOME/.cargo-nix

# Ensure `cargo fmt` uses `rustfmt` from nightly.
export RUSTFMT="${nightlyToolchain}/bin/rustfmt"

export C_INCLUDE_PATH="${llvmPackages_15.libclang.lib}/lib/clang/${llvmPackages_15.libclang.version}/include"
export CC="${clang-tools_15.clang}/bin/clang"
export AR="${llvm_15}/bin/llvm-ar"
export CFLAGS="-mcpu=generic"
devShell = baseShell;
# extra dev shells
devShells = {
# run with `nix develop .#cudaShell`
cudaShell =
let cudatoolkit = pkgsAllowUnfree.cudaPackages_12_3.cudatoolkit;
in baseShell.overrideAttrs (oldAttrs: {
# for GPU/CUDA env (e.g. to run ICICLE code)
name = "cuda-env-shell";
buildInputs = oldAttrs.buildInputs
++ [ cmake cudatoolkit util-linux gcc11 ];
# CXX is overridden to use gcc as icicle-curves's build scripts need them
shellHook = oldAttrs.shellHook + ''

# by default choose u64_backend
export RUSTFLAGS='--cfg curve25519_dalek_backend="u64"'
''
# install pre-commit hooks
+ self.check.${system}.pre-commit-check.shellHook;
export PATH="${pkgs.gcc11}/bin:${cudatoolkit}/bin:${cudatoolkit}/nvvm/bin:$PATH"
export LD_LIBRARY_PATH=${cudatoolkit}/lib
export CUDA_PATH=${cudatoolkit}
export CPATH="${cudatoolkit}/include"
export LIBRARY_PATH="$LIBRARY_PATH:/lib"
export CMAKE_CUDA_COMPILER=$CUDA_PATH/bin/nvcc
export LIBCLANG_PATH=${llvmPackages_15.libclang.lib}/lib
export CFLAGS=""
'';
});
};
}
);
});
}
20 changes: 20 additions & 0 deletions primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ generic-array = { version = "0", features = [
"serde",
] } # not a direct dependency, but we need serde
hashbrown = "0.14.3"
icicle-bls12-377 = { git = "https://github.com/ingonyama-zk/icicle.git", tag = "v1.5.0", optional = true, features = ["arkworks"] }
icicle-bls12-381 = { git = "https://github.com/ingonyama-zk/icicle.git", tag = "v1.5.0", optional = true, features = ["arkworks"] }
icicle-bn254 = { git = "https://github.com/ingonyama-zk/icicle.git", tag = "v1.5.0", optional = true, features = ["arkworks"] }
icicle-core = { git = "https://github.com/ingonyama-zk/icicle.git", tag = "v1.5.0", optional = true }
icicle-cuda-runtime = { git = "https://github.com/ingonyama-zk/icicle.git", tag = "v1.5.0", optional = true }
itertools = { workspace = true, features = ["use_alloc"] }
jf-relation = { path = "../relation", default-features = false }
jf-utils = { path = "../utilities" }
Expand Down Expand Up @@ -84,6 +89,12 @@ path = "benches/pcs_size.rs"
harness = false
required-features = ["test-srs"]

[[bench]]
name = "kzg-gpu"
path = "benches/kzg_gpu.rs"
harness = false
required-features = ["test-srs"]

[[bench]]
name = "reed-solomon"
path = "benches/reed_solomon.rs"
Expand Down Expand Up @@ -135,3 +146,12 @@ parallel = [
]
test-srs = []
seq-fk-23 = [] # FK23 without parallelism
icicle = [
"icicle-cuda-runtime",
"icicle-core",
"icicle-core/arkworks",
"icicle-bn254",
"icicle-bls12-381",
"icicle-bls12-377",
"parallel",
]
99 changes: 99 additions & 0 deletions primitives/benches/kzg_gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
//! This benchmark meant for MSM speed comparison between arkworks and
mrain marked this conversation as resolved.
Show resolved Hide resolved
//! GPU-accelerated code We use `UnivariateKzgPCS::commit()` as a proxy for MSM
//!
//! Run `cargo bench --bench kzg-gpu --features "test-srs icicle"`
use ark_bn254::Bn254;
#[cfg(feature = "icicle")]
use ark_ec::models::{short_weierstrass::Affine, CurveConfig};
use ark_ec::pairing::Pairing;
use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
#[cfg(feature = "icicle")]
use jf_primitives::icicle_deps::*;
use jf_primitives::pcs::{
prelude::{PolynomialCommitmentScheme, UnivariateKzgPCS},
StructuredReferenceString,
};
use jf_utils::test_rng;

const MIN_LOG_DEGREE: usize = 19;
const MAX_LOG_DEGREE: usize = 23;

/// running MSM using arkworks backend
pub fn kzg_ark<E: Pairing>(c: &mut Criterion) {
let mut group = c.benchmark_group("MSM with arkworks");
let mut rng = test_rng();

let supported_degree = 2usize.pow(MAX_LOG_DEGREE as u32);
let pp = UnivariateKzgPCS::<E>::gen_srs_for_testing(&mut rng, supported_degree).unwrap();

// setup for commit first
for log_degree in MIN_LOG_DEGREE..MAX_LOG_DEGREE {
let degree = 2usize.pow(log_degree as u32);
let (ck, _vk) = pp.trim(degree).unwrap();
let p = <DensePolynomial<E::ScalarField> as DenseUVPolynomial<E::ScalarField>>::rand(
degree, &mut rng,
);

group.bench_with_input(
BenchmarkId::from_parameter(log_degree),
&log_degree,
|b, _log_degree| b.iter(|| UnivariateKzgPCS::<E>::commit(&ck, &p).unwrap()),
);
}
group.finish();
}

/// running MSM using ICICLE backends
#[cfg(feature = "icicle")]
pub fn kzg_icicle<E, C>(c: &mut Criterion)
where
C: IcicleCurve + MSM<C>,
C::ScalarField: ArkConvertible<ArkEquivalent = E::ScalarField>,
C::BaseField: ArkConvertible<ArkEquivalent = <C::ArkSWConfig as CurveConfig>::BaseField>,
E: Pairing<G1Affine = Affine<<C as IcicleCurve>::ArkSWConfig>>,
{
let mut group = c.benchmark_group("MSM with ICICLE");
let mut rng = test_rng();

let supported_degree = 2usize.pow(MAX_LOG_DEGREE as u32);
let pp = UnivariateKzgPCS::<E>::gen_srs_for_testing(&mut rng, supported_degree).unwrap();
// TODO: (alex) figure out load longer SRS first, and only use part of it later
// currently it will error if the `scalars.len() % points.len() != 0`
// while we can tap into a slice behind the reference via
// `HostOrDeviceSlice[..]` which gives `&[T]` however msm() doesn't accept
// &[T], only `&HostOrDeviceSlice`

// setup for commit first
for log_degree in MIN_LOG_DEGREE..MAX_LOG_DEGREE {
let degree = 2usize.pow(log_degree as u32);
let (ck, _vk) = pp.trim(degree).unwrap();
let p = <DensePolynomial<E::ScalarField> as DenseUVPolynomial<E::ScalarField>>::rand(
degree, &mut rng,
);

group.bench_with_input(
BenchmarkId::from_parameter(log_degree),
&log_degree,
|b, _log_degree| {
b.iter(|| UnivariateKzgPCS::<E>::commit_with_gpu::<C>(&ck, &p).unwrap())
},
);
}
group.finish();
}

fn kzg_gpu_bn254(c: &mut Criterion) {
kzg_ark::<Bn254>(c);
#[cfg(feature = "icicle")]
kzg_icicle::<Bn254, icicle_bn254::curve::CurveCfg>(c);
}

criterion_group! {
name = kzg_gpu_benches;
config = Criterion::default().sample_size(10);
targets =
kzg_gpu_bn254,
}

criterion_main!(kzg_gpu_benches);
30 changes: 30 additions & 0 deletions primitives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,34 @@ pub mod vdf;
pub mod vid;
pub mod vrf;

/// dependecies required for ICICLE-related code, group import for convenience
#[cfg(feature = "icicle")]
pub mod icicle_deps {
pub use icicle_core::{
curve::{Affine as IcicleAffine, Curve as IcicleCurve, Projective as IcicleProjective},
msm::{MSMConfig, MSM},
traits::{ArkConvertible, FieldImpl},
};
pub use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream};

/// curve-specific types both from arkworks and from ICICLE
/// including Pairing, CurveCfg, Fr, Fq etc.
pub mod curves {
pub use ark_bls12_381::Bls12_381;
pub use ark_bn254::Bn254;
pub use icicle_bls12_381::curve::CurveCfg as IcicleBls12_381;
pub use icicle_bn254::curve::CurveCfg as IcicleBn254;
}

// TODO: remove this after `warmup()` is added upstream
// https://github.com/ingonyama-zk/icicle/pull/422#issuecomment-1980881638
/// Create a new stream and warmup
pub fn warmup_new_stream() -> Result<CudaStream, ()> {
let stream = CudaStream::create().unwrap();
// TODO: consider using an error type?
let _warmup_bytes = HostOrDeviceSlice::<'_, u8>::cuda_malloc_async(1024, &stream).unwrap();
Ok(stream)
}
}

pub(crate) mod utils;
21 changes: 21 additions & 0 deletions primitives/src/pcs/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ use crate::errors::PrimitivesError;
use ark_serialize::SerializationError;
use ark_std::string::{String, ToString};
use displaydoc::Display;
#[cfg(feature = "icicle")]
use icicle_core::error::IcicleError;
#[cfg(feature = "icicle")]
use icicle_cuda_runtime::error::CudaError;

/// A `enum` specifying the possible failure modes of the PCS.
#[derive(Display, Debug)]
Expand All @@ -29,6 +33,9 @@ pub enum PCSError {
TranscriptError(TranscriptError),
/// Error from upstream dependencies: {0}
UpstreamError(String),
#[cfg(feature = "icicle")]
/// Error from ICICLE: {0}
IcicleError(String),
}

impl ark_std::error::Error for PCSError {}
Expand All @@ -50,3 +57,17 @@ impl From<PrimitivesError> for PCSError {
Self::UpstreamError(e.to_string())
}
}

#[cfg(feature = "icicle")]
impl From<IcicleError> for PCSError {
fn from(e: IcicleError) -> Self {
Self::IcicleError(ark_std::format!("{:?}", e))
}
}
#[cfg(feature = "icicle")]
impl From<CudaError> for PCSError {
fn from(e: CudaError) -> Self {
let icicle_err = IcicleError::from_cuda_error(e);
icicle_err.into()
}
}
Loading
Loading