From 6d6fc66e90da3eddcd9f37d93361b77a6596fed9 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Tue, 8 Oct 2024 10:45:35 +0100 Subject: [PATCH] Delegate xorBitCount to Lucene (#114249) Now that we're on Lucene 9.12 we don't need our own optimized xorBitCount, can just delegate to Lucene's optimized one (which is identical). --- .../vectors/ES815BitFlatVectorsFormat.java | 4 +- .../field/vectors/ByteBinaryDenseVector.java | 2 +- .../field/vectors/ByteKnnDenseVector.java | 2 +- .../script/field/vectors/ESVectorUtil.java | 73 ------------------- 4 files changed, 4 insertions(+), 77 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index 5969c9d5db6d7..f0f25bd702749 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -17,11 +17,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; -import org.elasticsearch.script.field.vectors.ESVectorUtil; import java.io.IOException; @@ -105,7 +105,7 @@ public RandomVectorScorer getRandomVectorScorer( } static float hammingScore(byte[] a, byte[] b) { - return ((a.length * Byte.SIZE) - ESVectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); + return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); } static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index a01d1fcbdb4ed..8f13ada2fd604 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -103,7 +103,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return ESVectorUtil.xorBitCount(queryVector, vectorValue); + return VectorUtil.xorBitCount(queryVector, vectorValue); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java index a4219583824c3..42e5b5250199e 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java @@ -104,7 +104,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return ESVectorUtil.xorBitCount(queryVector, docVector); + return VectorUtil.xorBitCount(queryVector, docVector); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java deleted file mode 100644 index 045a0e5e75b04..0000000000000 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.script.field.vectors; - -import org.apache.lucene.util.BitUtil; -import org.apache.lucene.util.Constants; - -/** - * This class consists of a single utility method that provides XOR bit count computed over signed bytes. - * Remove this class when Lucene version > 9.11 is released, and replace with Lucene's VectorUtil directly. - */ -public class ESVectorUtil { - - /** - * For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. - * On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when - * compared to Integer::bitCount. While Long::bitCount is optimal on x64. - */ - static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64"); - - /** - * XOR bit count computed over signed bytes. - * - * @param a bytes containing a vector - * @param b bytes containing another vector, of the same dimension - * @return the value of the XOR bit count of the two vectors - */ - public static int xorBitCount(byte[] a, byte[] b) { - if (a.length != b.length) { - throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); - } - if (XOR_BIT_COUNT_STRIDE_AS_INT) { - return xorBitCountInt(a, b); - } else { - return xorBitCountLong(a, b); - } - } - - /** XOR bit count striding over 4 bytes at a time. */ - static int xorBitCountInt(byte[] a, byte[] b) { - int distance = 0, i = 0; - for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { - distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i)); - } - // tail: - for (; i < a.length; i++) { - distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); - } - return distance; - } - - /** XOR bit count striding over 8 bytes at a time. */ - static int xorBitCountLong(byte[] a, byte[] b) { - int distance = 0, i = 0; - for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { - distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i)); - } - // tail: - for (; i < a.length; i++) { - distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); - } - return distance; - } - - private ESVectorUtil() {} -}