Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MSM skip doubling when window has all zeros #152

Conversation

jonathanpwang
Copy link
Contributor

Closes #150

To be honest I did not have enough time to understand the full implementation of Cyclone MSM. However the principle that if an entire window has all 0 bits, then it can be totally skipped, seems like it can be carried over exactly the same.

@kilic kilic self-requested a review April 24, 2024 11:40
@ed255 ed255 self-requested a review April 26, 2024 14:24
ed255 added a commit that referenced this pull request May 3, 2024
@ed255
Copy link
Member

ed255 commented May 3, 2024

Do you have benchmark results for this change?

I tried running some and found the results surprising: there's a small speed improvement with k < 21, and at k > 21 there's a slowdown (that's only for the best_multiexp MSM implementation).

I ran a test where all the coefficients are 8 bits (so that the skipping of zeros can shine). These are my results

bits = 8
Start:   older k=18
End:     older k=18 ................................................................55.407ms
Start:   older_skip_zeros k=18
End:     older_skip_zeros k=18 .....................................................43.940ms
Start:   older k=19
End:     older k=19 ................................................................62.393ms
Start:   older_skip_zeros k=19
End:     older_skip_zeros k=19 .....................................................45.226ms
Start:   older k=20
End:     older k=20 ................................................................111.360ms
Start:   older_skip_zeros k=20
End:     older_skip_zeros k=20 .....................................................106.721ms
Start:   older k=21
End:     older k=21 ................................................................237.317ms
Start:   older_skip_zeros k=21
End:     older_skip_zeros k=21 .....................................................232.228ms
Start:   older k=22
End:     older k=22 ................................................................422.003ms
Start:   older_skip_zeros k=22
End:     older_skip_zeros k=22 .....................................................457.572ms

Tested via a6abbc8 with

cargo test --features print-trace --release test_msm_cross_small -- --nocapture

My CPU is AMD Ryzen 5 3600 6-Core Processor.

In case you benchmarked this, did you get different results? I wonder if my tests have any mistake that lead to this surprising result 🤔

@ed255
Copy link
Member

ed255 commented May 3, 2024

Here's a proposal for a different approach to skip zeros. Instead of preparing the coeffs in windows beforehand, just scan the coeffs and find the max amount of bytes they use, and then proceed with the algorithm only working with the max amount of bytes found.

This is what it looks like:

halo2curves/src/msm.rs

Lines 541 to 559 in db46631

let field_byte_size = coeffs[0].as_ref().len();
let mut acc_or = vec![0; field_byte_size];
for coeff in &coeffs {
for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
*acc_limb = *acc_limb | *limb;
}
}
let max_byte_size = field_byte_size
- acc_or
.iter()
.rev()
.position(|v| *v != 0)
.unwrap_or(field_byte_size);
if max_byte_size == 0 {
return;
}
let number_of_windows = max_byte_size * 8 as usize / c + 1;

Now if the max number of bytes is smaller than the field number of bytes, the number of windows to slide over will be smaller (thus skipping the windows that would pick the most significant bits, which were found to be zeroes).

These are the tests results I get

bits = 8
Start:   older k=18
End:     older k=18 ................................................................54.356ms
Start:   older_skip_zeros k=18
End:     older_skip_zeros k=18 .....................................................33.409ms
Start:   older k=19
End:     older k=19 ................................................................72.362ms
Start:   older_skip_zeros k=19
End:     older_skip_zeros k=19 .....................................................34.746ms
Start:   older k=20
End:     older k=20 ................................................................113.625ms
Start:   older_skip_zeros k=20
End:     older_skip_zeros k=20 .....................................................69.020ms
Start:   older k=21
End:     older k=21 ................................................................236.001ms
Start:   older_skip_zeros k=21
End:     older_skip_zeros k=21 .....................................................138.168ms
Start:   older k=22
End:     older k=22 ................................................................420.748ms
Start:   older_skip_zeros k=22
End:     older_skip_zeros k=22 .....................................................273.702ms

A very interesting next step that I was exploring a few weeks ago was adding metadata to the vector of coefficients to indicate how many bytes they use; because when writing a circuit we may already know how many bytes are used in certain columns; this way we skip the scanning to figure out max bytes. See privacy-scaling-explorations/halo2#315

@ed255
Copy link
Member

ed255 commented May 14, 2024

@jonathanpwang did you have time to take a look at this?

@jonathanpwang
Copy link
Contributor Author

Sorry I have not had a chance to look at this further. I had considered your suggestion before: I was worried that for values that weren't small bits, the initial scan would add an unwanted overhead. What happens if you run your second approach on full size scalars?

@ed255
Copy link
Member

ed255 commented Jun 4, 2024

I ran the test with full size scalars and found that the older implementation was slightly slower than the one I suggest, which doesn't make sense. The suggestion should have a small overhead. The test was running for each k, first the old implementation and then my suggested one. Then I tried swapping the order and the results were that my suggestion was slightly slower. So I'm thinking this way of comparing isn't very good; my guess is that the second one has advantage because some data is already in the cache? So I decided to work on testing this with proper benchmarks and will report the results once I have them.

@ed255
Copy link
Member

ed255 commented Jun 4, 2024

Did some benches with criterion to compare the original msm, your proposal and my proposal; with big values and 8 bit values. Here are the results:

msm_cmp/msm func=original, k=18, small=false/18
                        time:   [403.71 ms 405.78 ms 408.27 ms]
msm_cmp/msm func=skip_zeros_edu, k=18, small=false/18
                        time:   [401.60 ms 403.38 ms 405.09 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=18, small=false/18
                        time:   [403.42 ms 408.09 ms 414.71 ms]

msm_cmp/msm func=original, k=19, small=false/19
                        time:   [732.35 ms 735.62 ms 738.80 ms]
msm_cmp/msm func=skip_zeros_edu, k=19, small=false/19
                        time:   [733.36 ms 737.46 ms 742.00 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=19, small=false/19
                        time:   [733.89 ms 737.76 ms 741.98 ms]

msm_cmp/msm func=original, k=20, small=false/20
                        time:   [1.3542 s 1.3581 s 1.3628 s]
msm_cmp/msm func=skip_zeros_edu, k=20, small=false/20
                        time:   [1.3557 s 1.3579 s 1.3603 s]
msm_cmp/msm func=skip_zeros_jonathan, k=20, small=false/20
                        time:   [1.3712 s 1.3743 s 1.3774 s]

msm_cmp/msm func=original, k=21, small=false/21
                        time:   [2.4830 s 2.4977 s 2.5175 s]
msm_cmp/msm func=skip_zeros_edu, k=21, small=false/21
                        time:   [2.4815 s 2.4879 s 2.4959 s]
msm_cmp/msm func=skip_zeros_jonathan, k=21, small=false/21
                        time:   [2.5476 s 2.5581 s 2.5750 s]

msm_cmp/msm func=original, k=22, small=false/22
                        time:   [4.9031 s 4.9095 s 4.9170 s]
msm_cmp/msm func=skip_zeros_edu, k=22, small=false/22
                        time:   [4.9056 s 4.9143 s 4.9263 s]
msm_cmp/msm func=skip_zeros_jonathan, k=22, small=false/22
                        time:   [5.0393 s 5.0469 s 5.0555 s]

msm_cmp/msm func=original, k=18, small=true/18
                        time:   [29.141 ms 29.233 ms 29.313 ms]
msm_cmp/msm func=skip_zeros_edu, k=18, small=true/18
                        time:   [15.946 ms 16.053 ms 16.169 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=18, small=true/18
                        time:   [19.605 ms 19.703 ms 19.856 ms]

msm_cmp/msm func=original, k=19, small=true/19
                        time:   [56.734 ms 57.711 ms 58.746 ms]
msm_cmp/msm func=skip_zeros_edu, k=19, small=true/19
                        time:   [33.380 ms 34.029 ms 34.473 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=19, small=true/19
                        time:   [39.459 ms 40.114 ms 40.975 ms]

msm_cmp/msm func=original, k=20, small=true/20
                        time:   [112.87 ms 113.69 ms 114.41 ms]
msm_cmp/msm func=skip_zeros_edu, k=20, small=true/20
                        time:   [68.146 ms 68.810 ms 69.561 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=20, small=true/20
                        time:   [90.475 ms 90.892 ms 91.278 ms]

msm_cmp/msm func=original, k=21, small=true/21
                        time:   [230.04 ms 231.36 ms 232.75 ms]
msm_cmp/msm func=skip_zeros_edu, k=21, small=true/21
                        time:   [137.37 ms 140.19 ms 143.54 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=21, small=true/21
                        time:   [226.41 ms 226.94 ms 227.50 ms]

msm_cmp/msm func=original, k=22, small=true/22
                        time:   [426.24 ms 427.67 ms 429.04 ms]
msm_cmp/msm func=skip_zeros_edu, k=22, small=true/22
                        time:   [270.63 ms 273.88 ms 278.03 ms]
msm_cmp/msm func=skip_zeros_jonathan, k=22, small=true/22
                        time:   [457.01 ms 460.22 ms 465.00 ms]

The summary is:

  • On big values my proposal has a negligible difference with the original, yours has a very small overhead
  • On 8 bit values my proposal is significantly faster than yours. At k=22 your proposal gets behind the original one.

To reproduce checkout this commit d1f79a5 and run with

cargo bench --bench msm msm_cmp

@jonathanpwang
Copy link
Contributor Author

Yes interesting, thanks for the benchmarks. I'm guessing your scan just preloads stuff into cache so it doesn't have much slowdown.

I am in favor of going with yours.

@ed255 ed255 mentioned this pull request Jul 2, 2024
@davidnevadoc
Copy link
Contributor

Superseded by #168

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Continue multiexp_serial skips doubling when all bits are zero.
3 participants