[DRAFT] Implement Dot Product Similarity Distance algorithm #120

Open: wants to merge 3 commits into master

Taenerys

Why this change?

  • As of this writing, MuopDB implements the Euclidean distance (L2) algorithm for calculating the distance between vectors.
  • We would like to support multiple distance calculators, starting with Dot Product Similarity as the first additional algorithm.
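For reference, the dot product of two vectors of the same dimension is the sum of their element-wise products, and a larger value means the vectors are more similar. A distance calculator is usually expected to return smaller values for closer vectors (as L2 does), so one common convention is to negate the dot product. The sketch below assumes that convention; the names `dot_product` and `dot_product_distance` are hypothetical and not taken from MuopDB.

```rust
/// Sum of element-wise products: a[0]*b[0] + a[1]*b[1] + ...
/// Hypothetical helper, not MuopDB's API. Both slices must have the same length.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Assumed convention: negate the similarity so that "smaller means closer",
/// matching the ordering of an L2 distance.
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    -dot_product(a, b)
}

fn main() {
    let a = [1.0_f32, 2.0, 3.0];
    let b = [4.0_f32, 5.0, 6.0];
    println!("similarity = {}", dot_product(&a, &b)); // 4 + 10 + 18 = 32
    println!("distance   = {}", dot_product_distance(&a, &b)); // -32
}
```

Whether MuopDB should negate the value, or instead rank by descending similarity, is a design decision for this PR; the sketch only shows the arithmetic.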

What changes have been made?

  • dot_product_similarity.rs currently contains the Scalar and SIMD implementations
  • Unit tests for a basic case, plus a consistency test
  • Benchmarks (current results are posted below)
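A minimal sketch of what a scalar implementation, an 8-lane implementation, and a basic-case plus consistency test could look like. It assumes stable Rust, so a plain `[f32; 8]` array of partial sums stands in for the `f32x8` accumulator the PR uses from `std::simd` (which requires nightly); `dot_scalar`, `dot_lanes8` and `approx_eq` are hypothetical names. Because the two versions add the products in a different order, the consistency check uses a tolerance instead of exact equality.

```rust
/// Straightforward scalar dot product.
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// 8-lane version: one partial sum per lane (like an f32x8 accumulator),
/// reduced at the end, plus the tail elements that do not fill a chunk of 8.
fn dot_lanes8(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let mut acc = [0.0_f32; 8];
    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    // Leftover elements after the last full chunk, summed with scalar code.
    let tail: f32 = a_chunks
        .remainder()
        .iter()
        .zip(b_chunks.remainder())
        .map(|(x, y)| x * y)
        .sum();
    for (ca, cb) in a_chunks.zip(b_chunks) {
        for lane in 0..8 {
            acc[lane] += ca[lane] * cb[lane];
        }
    }
    acc.iter().sum::<f32>() + tail
}

/// Relative tolerance with an absolute floor of `rel_tol` for values near zero.
fn approx_eq(x: f32, y: f32, rel_tol: f32) -> bool {
    (x - y).abs() <= rel_tol * x.abs().max(y.abs()).max(1.0)
}

fn main() {
    // Basic case: 1*4 + 2*5 + 3*6 = 32.
    assert_eq!(dot_scalar(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);

    // Consistency: 37 is not a multiple of 8, so the tail path is exercised too.
    let a: Vec<f32> = (0..37).map(|i| (i as f32) * 0.5 - 3.0).collect();
    let b: Vec<f32> = (0..37).map(|i| 1.0 - (i as f32) * 0.25).collect();
    let (s, v) = (dot_scalar(&a, &b), dot_lanes8(&a, &b));
    assert!(approx_eq(s, v, 1e-5), "scalar {s} vs lanes {v}");
    println!("scalar = {s}, lanes8 = {v}");
}
```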

Creating this draft PR for a first round of review. There is definitely room for improvement, so I would appreciate any input or feedback you can give. Thanks so much!

Taenerys requested a review from hicder on November 17, 2024 22:49
Taenerys commented Nov 17, 2024

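Overall observations from the benchmark:

  • SIMD outperforms Scalar at every tested dimension except 16, as expected since the f32x8 type processes eight lanes at a time. The gap grows with dimension, reaching roughly 6-7x at 768 and above.
  • Many runs report a noticeable share of outliers (up to 17% at dimension 384); the cause still needs to be investigated.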

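Current benchmark results (Criterion output; each time line shows the lower bound, point estimate, and upper bound of the confidence interval):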

Dot Product Similarity/Scalar/8
                        time:   [6.1531 ns 6.1579 ns 6.1651 ns]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/SIMD/8
                        time:   [5.0491 ns 5.0561 ns 5.0654 ns]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/Calculate/8
                        time:   [2.8731 ns 2.8785 ns 2.8866 ns]
Found 9 outliers among 100 measurements (9.00%)
  5 (5.00%) high mild
  4 (4.00%) high severe
Dot Product Similarity/Scalar/16
                        time:   [4.5122 ns 4.5208 ns 4.5346 ns]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/SIMD/16
                        time:   [6.2826 ns 6.2849 ns 6.2875 ns]
Found 6 outliers among 100 measurements (6.00%)
  5 (5.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Calculate/16
                        time:   [4.0977 ns 4.1004 ns 4.1043 ns]
Found 6 outliers among 100 measurements (6.00%)
  1 (1.00%) high mild
  5 (5.00%) high severe
Dot Product Similarity/Scalar/32
                        time:   [10.396 ns 10.399 ns 10.403 ns]
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/SIMD/32
                        time:   [8.9192 ns 8.9363 ns 8.9664 ns]
Found 4 outliers among 100 measurements (4.00%)
  2 (2.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/Calculate/32
                        time:   [9.2737 ns 9.2842 ns 9.2945 ns]
Found 5 outliers among 100 measurements (5.00%)
  2 (2.00%) low mild
  2 (2.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Scalar/64
                        time:   [26.888 ns 26.898 ns 26.908 ns]
Found 4 outliers among 100 measurements (4.00%)
  2 (2.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/SIMD/64
                        time:   [14.238 ns 14.246 ns 14.257 ns]
Found 7 outliers among 100 measurements (7.00%)
  6 (6.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Calculate/64
                        time:   [14.617 ns 14.630 ns 14.643 ns]
Dot Product Similarity/Scalar/128
                        time:   [73.736 ns 74.079 ns 74.728 ns]
Found 8 outliers among 100 measurements (8.00%)
  5 (5.00%) high mild
  3 (3.00%) high severe
Dot Product Similarity/SIMD/128
                        time:   [24.946 ns 24.971 ns 24.998 ns]
Found 6 outliers among 100 measurements (6.00%)
  4 (4.00%) high mild
  2 (2.00%) high severe
Dot Product Similarity/Calculate/128
                        time:   [25.387 ns 25.599 ns 25.918 ns]
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high severe
Dot Product Similarity/Scalar/256
                        time:   [180.13 ns 182.63 ns 186.01 ns]
Found 9 outliers among 100 measurements (9.00%)
  1 (1.00%) high mild
  8 (8.00%) high severe
Dot Product Similarity/SIMD/256
                        time:   [46.379 ns 46.567 ns 46.820 ns]
Found 7 outliers among 100 measurements (7.00%)
  1 (1.00%) high mild
  6 (6.00%) high severe
Dot Product Similarity/Calculate/256
                        time:   [47.329 ns 47.895 ns 48.667 ns]
Found 8 outliers among 100 measurements (8.00%)
  2 (2.00%) high mild
  6 (6.00%) high severe
Dot Product Similarity/Scalar/512
                        time:   [472.15 ns 475.68 ns 481.15 ns]
Found 7 outliers among 100 measurements (7.00%)
  2 (2.00%) high mild
  5 (5.00%) high severe
Dot Product Similarity/SIMD/512
                        time:   [89.693 ns 90.001 ns 90.338 ns]
Found 12 outliers among 100 measurements (12.00%)
  8 (8.00%) high mild
  4 (4.00%) high severe
Dot Product Similarity/Calculate/512
                        time:   [89.198 ns 90.196 ns 91.740 ns]
Found 6 outliers among 100 measurements (6.00%)
  1 (1.00%) high mild
  5 (5.00%) high severe
Dot Product Similarity/Scalar/1024
                        time:   [1.0973 µs 1.0983 µs 1.1000 µs]
Found 5 outliers among 100 measurements (5.00%)
  4 (4.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/SIMD/1024
                        time:   [181.61 ns 181.98 ns 182.54 ns]
Found 2 outliers among 100 measurements (2.00%)
  1 (1.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Calculate/1024
                        time:   [180.00 ns 180.07 ns 180.16 ns]
Found 5 outliers among 100 measurements (5.00%)
  4 (4.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Scalar/2048
                        time:   [2.3779 µs 2.3905 µs 2.4054 µs]
Found 7 outliers among 100 measurements (7.00%)
  1 (1.00%) high mild
  6 (6.00%) high severe
Dot Product Similarity/SIMD/2048
                        time:   [351.17 ns 351.62 ns 352.19 ns]
Dot Product Similarity/Calculate/2048
                        time:   [350.29 ns 350.44 ns 350.61 ns]
Found 10 outliers among 100 measurements (10.00%)
  6 (6.00%) high mild
  4 (4.00%) high severe
Dot Product Similarity/Scalar/4096
                        time:   [4.9188 µs 4.9391 µs 4.9640 µs]
Found 7 outliers among 100 measurements (7.00%)
  1 (1.00%) high mild
  6 (6.00%) high severe
Dot Product Similarity/SIMD/4096
                        time:   [744.64 ns 777.34 ns 821.19 ns]
Found 14 outliers among 100 measurements (14.00%)
  1 (1.00%) high mild
  13 (13.00%) high severe
Dot Product Similarity/Calculate/4096
                        time:   [691.46 ns 697.79 ns 708.44 ns]
Found 7 outliers among 100 measurements (7.00%)
  4 (4.00%) high mild
  3 (3.00%) high severe
Dot Product Similarity/Scalar/384
                        time:   [320.12 ns 332.84 ns 349.74 ns]
Found 17 outliers among 100 measurements (17.00%)
  5 (5.00%) high mild
  12 (12.00%) high severe
Dot Product Similarity/SIMD/384
                        time:   [69.103 ns 70.707 ns 72.822 ns]
Found 11 outliers among 100 measurements (11.00%)
  5 (5.00%) high mild
  6 (6.00%) high severe
Dot Product Similarity/Calculate/384
                        time:   [69.548 ns 72.232 ns 75.753 ns]
Found 16 outliers among 100 measurements (16.00%)
  1 (1.00%) high mild
  15 (15.00%) high severe
Dot Product Similarity/Scalar/768
                        time:   [798.23 ns 825.61 ns 859.45 ns]
Found 11 outliers among 100 measurements (11.00%)
  3 (3.00%) high mild
  8 (8.00%) high severe
Dot Product Similarity/SIMD/768
                        time:   [131.86 ns 132.00 ns 132.14 ns]
Found 2 outliers among 100 measurements (2.00%)
  1 (1.00%) high mild
  1 (1.00%) high severe
Dot Product Similarity/Calculate/768
                        time:   [132.65 ns 133.56 ns 134.90 ns]
Found 5 outliers among 100 measurements (5.00%)
  1 (1.00%) high mild
  4 (4.00%) high severe
Dot Product Similarity/Scalar/1536
                        time:   [1.7333 µs 1.7358 µs 1.7383 µs]
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild
Dot Product Similarity/SIMD/1536
                        time:   [265.87 ns 267.65 ns 270.28 ns]
Found 4 outliers among 100 measurements (4.00%)
  4 (4.00%) high severe
Dot Product Similarity/Calculate/1536
                        time:   [264.83 ns 264.90 ns 264.98 ns]
Found 8 outliers among 100 measurements (8.00%)
  4 (4.00%) high mild
  4 (4.00%) high severe
Dot Product Similarity/Scalar/3072
                        time:   [3.6107 µs 3.6124 µs 3.6145 µs]
Found 12 outliers among 100 measurements (12.00%)
  5 (5.00%) high mild
  7 (7.00%) high severe
Dot Product Similarity/SIMD/3072
                        time:   [520.31 ns 520.43 ns 520.57 ns]
Found 2 outliers among 100 measurements (2.00%)
  2 (2.00%) high severe
Dot Product Similarity/Calculate/3072
                        time:   [520.66 ns 521.04 ns 521.48 ns]
Found 12 outliers among 100 measurements (12.00%)
  5 (5.00%) high mild
  7 (7.00%) high severe
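To put the first observation in numbers, the SIMD speedup can be computed from the point estimate (middle value) of each time line above. The small program below hard-codes those estimates, converted to nanoseconds and sorted by dimension, and prints Scalar time divided by SIMD time; `RESULTS` and `speedups` are illustrative names, and the program only re-derives figures from the posted results.

```rust
/// (dimension, Scalar ns, SIMD ns): point estimates copied from the results above.
const RESULTS: [(u32, f64, f64); 14] = [
    (8, 6.1579, 5.0561),
    (16, 4.5208, 6.2849),
    (32, 10.399, 8.9363),
    (64, 26.898, 14.246),
    (128, 74.079, 24.971),
    (256, 182.63, 46.567),
    (384, 332.84, 70.707),
    (512, 475.68, 90.001),
    (768, 825.61, 132.00),
    (1024, 1098.3, 181.98),
    (1536, 1735.8, 267.65),
    (2048, 2390.5, 351.62),
    (3072, 3612.4, 520.43),
    (4096, 4939.1, 777.34),
];

/// Scalar time divided by SIMD time; a value above 1.0 means SIMD is faster.
fn speedups() -> Vec<(u32, f64)> {
    RESULTS
        .iter()
        .map(|&(dim, scalar, simd)| (dim, scalar / simd))
        .collect()
}

fn main() {
    for (dim, s) in speedups() {
        println!("dim {dim:>4}: SIMD speedup = {s:.2}x");
    }
}
```

Dimension 16 comes out below 1.0 (about 0.72x), and every dimension from 768 upward falls between 6x and 7x. Given the outlier counts, small differences between neighbouring dimensions should not be over-interpreted.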

Taenerys linked an issue on Nov 17, 2024 that may be closed by this pull request
Taenerys (author) commented:

Just read @hicder's merged PR.

That PR has a great idea: optimizing the code by performing the computations inline, directly within SIMD registers. I'm going to give this a try in my implementation.
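One reading of this idea (an assumption; hicder's PR is not reproduced here) is to keep the running sum in a local variable inside the hot loop, which the compiler can hold in a register, and write it back to the struct only once, instead of updating a field such as `dist_simd_8` through `self` on every iteration. The sketch uses stable Rust with a `[f32; 8]` standing in for `f32x8`; `DotProductCalculator`, `calculate_via_field`, `calculate_via_local` and `tail` are hypothetical names.

```rust
/// Hypothetical calculator mirroring the `dist_simd_8` field used in the PR,
/// with a plain [f32; 8] standing in for std::simd's f32x8.
struct DotProductCalculator {
    dist_simd_8: [f32; 8],
}

impl DotProductCalculator {
    /// Pattern from the PR: every iteration updates the struct field.
    fn calculate_via_field(&mut self, a: &[f32], b: &[f32]) -> f32 {
        self.dist_simd_8 = [0.0; 8];
        let mut i = 0;
        while i + 8 <= a.len() && i + 8 <= b.len() {
            for lane in 0..8 {
                self.dist_simd_8[lane] += a[i + lane] * b[i + lane];
            }
            i += 8;
        }
        self.dist_simd_8.iter().sum::<f32>() + tail(&a[i..], &b[i..])
    }

    /// Register-friendly variant: accumulate in a local, store once at the end.
    fn calculate_via_local(&mut self, a: &[f32], b: &[f32]) -> f32 {
        let mut acc = [0.0_f32; 8];
        let mut i = 0;
        while i + 8 <= a.len() && i + 8 <= b.len() {
            for lane in 0..8 {
                acc[lane] += a[i + lane] * b[i + lane];
            }
            i += 8;
        }
        self.dist_simd_8 = acc; // single write-back after the hot loop
        acc.iter().sum::<f32>() + tail(&a[i..], &b[i..])
    }
}

/// Scalar sum over the elements left after the last full chunk of 8.
fn tail(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    let a: Vec<f32> = (0..20).map(|i| i as f32).collect();
    let b = vec![1.0_f32; 20];
    let mut calc = DotProductCalculator { dist_simd_8: [0.0; 8] };
    let via_field = calc.calculate_via_field(&a, &b);
    let via_local = calc.calculate_via_local(&a, &b);
    // Both equal 190 (= 0 + 1 + ... + 19).
    println!("{via_field} {via_local}");
}
```

Both variants compute the same value; whether the local version is actually faster depends on what the optimizer already does with the field, so it is worth confirming with the existing benchmarks.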

Comment on lines +40 to +45
    while i + step <= a.len() {
        let a_slice = f32x8::from_slice(&a[i..]);
        let b_slice = f32x8::from_slice(&b[i..]);
        self.dist_simd_8 += a_slice * b_slice;
        i += step;
    }
Comment from the repository owner:

Suggested change (replacing the quoted lines above):
    while i + 16 <= a.len() && i + 16 <= b.len() {
        let a_slice = f32x16::from_slice(&a[i..i + 16]);
        let b_slice = f32x16::from_slice(&b[i..i + 16]);
        self.dist_simd_16 += a_slice * b_slice;
        i += 16;
    }
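Whatever the loop width, the index has to advance by exactly the number of elements read per iteration, and the elements that do not fill a full chunk still need to be handled. A minimal stable-Rust sketch of a cascading loop (16-wide chunks, then at most one 8-wide chunk, then a scalar tail) is below; `lanes` and `dot_cascade` are hypothetical names, and `[f32; W]` arrays stand in for `f32x16`/`f32x8`.

```rust
/// Sums a[i]*b[i] from `start` using `W` lane accumulators over full chunks only.
/// Returns the partial sum and the index where the next stage should begin.
fn lanes<const W: usize>(a: &[f32], b: &[f32], start: usize, len: usize) -> (f32, usize) {
    let mut acc = [0.0_f32; W];
    let mut i = start;
    while i + W <= len {
        for lane in 0..W {
            acc[lane] += a[i + lane] * b[i + lane];
        }
        i += W; // advance by exactly the number of elements consumed
    }
    (acc.iter().sum::<f32>(), i)
}

/// Cascading dot product: 16-wide chunks, then 8-wide, then a scalar tail.
/// Uses the shorter length if the slices differ, like the `a.len()`/`b.len()` checks above.
fn dot_cascade(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    let (s16, i) = lanes::<16>(a, b, 0, len);
    let (s8, i) = lanes::<8>(a, b, i, len);
    let tail: f32 = a[i..len].iter().zip(&b[i..len]).map(|(x, y)| x * y).sum();
    s16 + s8 + tail
}

fn main() {
    // 27 = 16 + 8 + 3 elements, so all three stages run.
    let a: Vec<f32> = (1..=27).map(|i| i as f32).collect();
    let b = vec![2.0_f32; 27];
    println!("{}", dot_cascade(&a, &b)); // 2 * (1 + ... + 27) = 756
}
```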

Development: successfully merging this pull request may close the issue "Add support for Dot product Similarity".
Participants: 2