From e10fc3c90dc18da0b6dd02a06113899e0be0c5de Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 2 Dec 2024 12:19:03 -0500 Subject: [PATCH] Speed up bit compared with floats or bytes script operations (#117199) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of doing an "if" statement, which doesn't lend itself to vectorization, I switched to expand to the bits and multiply the 1s and 0s. This led to a marginal speed improvement on ARM. I expect that Panama vector could be used here to be even faster, but I didn't want to spend anymore time on this for the time being. ``` Benchmark (dims) Mode Cnt Score Error Units IpBitVectorScorerBenchmark.dotProductByteIfStatement 768 thrpt 5 2.952 ± 0.026 ops/us IpBitVectorScorerBenchmark.dotProductByteUnwrap 768 thrpt 5 4.017 ± 0.068 ops/us IpBitVectorScorerBenchmark.dotProductFloatIfStatement 768 thrpt 5 2.987 ± 0.124 ops/us IpBitVectorScorerBenchmark.dotProductFloatUnwrap 768 thrpt 5 4.726 ± 0.136 ops/us ``` Benchmark I used. https://gist.github.com/benwtrent/b0edb3975d2f03356c1a5ea84c72abc9 --- docs/changelog/117199.yaml | 5 ++ .../elasticsearch/simdvec/ESVectorUtil.java | 23 +------ .../DefaultESVectorUtilSupport.java | 65 +++++++++++++++++++ .../vectorization/ESVectorUtilSupport.java | 4 ++ .../PanamaESVectorUtilSupport.java | 10 +++ 5 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 docs/changelog/117199.yaml diff --git a/docs/changelog/117199.yaml b/docs/changelog/117199.yaml new file mode 100644 index 0000000000000..b685e98b61f6b --- /dev/null +++ b/docs/changelog/117199.yaml @@ -0,0 +1,5 @@ +pr: 117199 +summary: Speed up bit compared with floats or bytes script operations +area: Vector Search +type: enhancement +issues: [] diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 2f4743a47a14a..7fe475e86a2f5 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -61,17 +61,7 @@ public static int ipByteBit(byte[] q, byte[] d) { if (q.length != d.length * Byte.SIZE) { throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); } - int result = 0; - // now combine the two vectors, summing the byte dimensions where the bit in d is `1` - for (int i = 0; i < d.length; i++) { - byte mask = d[i]; - for (int j = Byte.SIZE - 1; j >= 0; j--) { - if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; - } - } - } - return result; + return IMPL.ipByteBit(q, d); } /** @@ -87,16 +77,7 @@ public static float ipFloatBit(float[] q, byte[] d) { if (q.length != d.length * Byte.SIZE) { throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); } - float result = 0; - for (int i = 0; i < d.length; i++) { - byte mask = d[i]; - for (int j = Byte.SIZE - 1; j >= 0; j--) { - if ((mask & (1 << j)) != 0) { - result += q[i * Byte.SIZE + Byte.SIZE - 1 - j]; - } - } - } - return result; + return IMPL.ipFloatBit(q, d); } /** diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index 4a08096119d6a..00381c8c3fb2f 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -10,9 +10,18 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.Constants; final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { + private static float fma(float a, float b, float c) { + if (Constants.HAS_FAST_SCALAR_FMA) { + return Math.fma(a, b, c); + } else { + return a * b + c; + } + } + DefaultESVectorUtilSupport() {} @Override @@ -20,6 +29,62 @@ public long ipByteBinByte(byte[] q, byte[] d) { return ipByteBinByteImpl(q, d); } + @Override + public int ipByteBit(byte[] q, byte[] d) { + return ipByteBitImpl(q, d); + } + + @Override + public float ipFloatBit(float[] q, byte[] d) { + return ipFloatBitImpl(q, d); + } + + public static int ipByteBitImpl(byte[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + int acc0 = 0; + int acc1 = 0; + int acc2 = 0; + int acc3 = 0; + // now combine the two vectors, summing the byte dimensions where the bit in d is `1` + for (int i = 0; i < d.length; i++) { + byte mask = d[i]; + // Make sure its just 1 or 0 + + acc0 += q[i * Byte.SIZE + 0] * ((mask >> 7) & 1); + acc1 += q[i * Byte.SIZE + 1] * ((mask >> 6) & 1); + acc2 += q[i * Byte.SIZE + 2] * ((mask >> 5) & 1); + acc3 += q[i * Byte.SIZE + 3] * ((mask >> 4) & 1); + + acc0 += q[i * Byte.SIZE + 4] * ((mask >> 3) & 1); + acc1 += q[i * Byte.SIZE + 5] * ((mask >> 2) & 1); + acc2 += q[i * Byte.SIZE + 6] * ((mask >> 1) & 1); + acc3 += q[i * Byte.SIZE + 7] * ((mask >> 0) & 1); + } + return acc0 + acc1 + acc2 + acc3; + } + + public static float ipFloatBitImpl(float[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + float acc0 = 0; + float acc1 = 0; + float acc2 = 0; + float acc3 = 0; + // now combine the two vectors, summing the byte dimensions where the bit in d is `1` + for (int i = 0; i < d.length; i++) { + byte mask = d[i]; + acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0); + acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1); + acc2 = fma(q[i * Byte.SIZE + 2], (mask >> 5) & 1, acc2); + acc3 = fma(q[i * Byte.SIZE + 3], (mask >> 4) & 1, acc3); + + acc0 = fma(q[i * Byte.SIZE + 4], (mask >> 3) & 1, acc0); + acc1 = fma(q[i * Byte.SIZE + 5], (mask >> 2) & 1, acc1); + acc2 = fma(q[i * Byte.SIZE + 6], (mask >> 1) & 1, acc2); + acc3 = fma(q[i * Byte.SIZE + 7], (mask >> 0) & 1, acc3); + } + return acc0 + acc1 + acc2 + acc3; + } + public static long ipByteBinByteImpl(byte[] q, byte[] d) { long ret = 0; int size = d.length; diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index d7611173ca693..6938bffec5f37 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -14,4 +14,8 @@ public interface ESVectorUtilSupport { short B_QUERY = 4; long ipByteBinByte(byte[] q, byte[] d); + + int ipByteBit(byte[] q, byte[] d); + + float ipFloatBit(float[] q, byte[] d); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 0e5827d046736..4de33643258e4 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -48,6 +48,16 @@ public long ipByteBinByte(byte[] q, byte[] d) { return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d); } + @Override + public int ipByteBit(byte[] q, byte[] d) { + return DefaultESVectorUtilSupport.ipByteBitImpl(q, d); + } + + @Override + public float ipFloatBit(float[] q, byte[] d) { + return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d); + } + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256;