Skip to content

Commit

Permalink
Speed up vectorutil float scalar methods, unroll properly, use fma wh…
Browse files Browse the repository at this point in the history
…ere possible (#12737)

Co-authored-by: Uwe Schindler <[email protected]>
  • Loading branch information
rmuir and uschindler committed Nov 4, 2023
1 parent 71b3e4c commit 3a2ee7f
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,72 +17,46 @@

package org.apache.lucene.internal.vectorization;

import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;

final class DefaultVectorUtilSupport implements VectorUtilSupport {

DefaultVectorUtilSupport() {}

/**
 * Computes {@code a * b + c}. When {@link Constants#HAS_FAST_SCALAR_FMA} is set the computation is
 * delegated to {@link Math#fma(float, float, float)}; otherwise it is evaluated as a separate
 * multiply followed by an add.
 */
@SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
private static float fma(float a, float b, float c) {
  return Constants.HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c;
}

@Override
public float dotProduct(float[] a, float[] b) {
float res = 0f;
/*
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
* calculation.
*/
int i;
for (i = 0; i < a.length % 8; i++) {
res += b[i] * a[i];
}
if (a.length < 8) {
return res;
}
for (; i + 31 < a.length; i += 32) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
res +=
b[i + 8] * a[i + 8]
+ b[i + 9] * a[i + 9]
+ b[i + 10] * a[i + 10]
+ b[i + 11] * a[i + 11]
+ b[i + 12] * a[i + 12]
+ b[i + 13] * a[i + 13]
+ b[i + 14] * a[i + 14]
+ b[i + 15] * a[i + 15];
res +=
b[i + 16] * a[i + 16]
+ b[i + 17] * a[i + 17]
+ b[i + 18] * a[i + 18]
+ b[i + 19] * a[i + 19]
+ b[i + 20] * a[i + 20]
+ b[i + 21] * a[i + 21]
+ b[i + 22] * a[i + 22]
+ b[i + 23] * a[i + 23];
res +=
b[i + 24] * a[i + 24]
+ b[i + 25] * a[i + 25]
+ b[i + 26] * a[i + 26]
+ b[i + 27] * a[i + 27]
+ b[i + 28] * a[i + 28]
+ b[i + 29] * a[i + 29]
+ b[i + 30] * a[i + 30]
+ b[i + 31] * a[i + 31];
int i = 0;

// if the array is big, unroll it
if (a.length > 32) {
float acc1 = 0;
float acc2 = 0;
float acc3 = 0;
float acc4 = 0;
int upperBound = a.length & ~(4 - 1);
for (; i < upperBound; i += 4) {
acc1 = fma(a[i], b[i], acc1);
acc2 = fma(a[i + 1], b[i + 1], acc2);
acc3 = fma(a[i + 2], b[i + 2], acc3);
acc4 = fma(a[i + 3], b[i + 3], acc4);
}
res += acc1 + acc2 + acc3 + acc4;
}
for (; i + 7 < a.length; i += 8) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];

for (; i < a.length; i++) {
res = fma(a[i], b[i], res);
}
return res;
}
Expand All @@ -92,50 +66,80 @@ public float cosine(float[] a, float[] b) {
float sum = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
int dim = a.length;
int i = 0;

for (int i = 0; i < dim; i++) {
float elem1 = a[i];
float elem2 = b[i];
sum += elem1 * elem2;
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
// if the array is big, unroll it
if (a.length > 32) {
float sum1 = 0;
float sum2 = 0;
float norm1_1 = 0;
float norm1_2 = 0;
float norm2_1 = 0;
float norm2_2 = 0;

int upperBound = a.length & ~(2 - 1);
for (; i < upperBound; i += 2) {
// one
sum1 = fma(a[i], b[i], sum1);
norm1_1 = fma(a[i], a[i], norm1_1);
norm2_1 = fma(b[i], b[i], norm2_1);

// two
sum2 = fma(a[i + 1], b[i + 1], sum2);
norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
}
sum += sum1 + sum2;
norm1 += norm1_1 + norm1_2;
norm2 += norm2_1 + norm2_2;
}

for (; i < a.length; i++) {
sum = fma(a[i], b[i], sum);
norm1 = fma(a[i], a[i], norm1);
norm2 = fma(b[i], b[i], norm2);
}
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}

@Override
public float squareDistance(float[] a, float[] b) {
float squareSum = 0.0f;
int dim = a.length;
int i;
for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled(a, b, i);
float res = 0;
int i = 0;

// if the array is big, unroll it
if (a.length > 32) {
float acc1 = 0;
float acc2 = 0;
float acc3 = 0;
float acc4 = 0;

int upperBound = a.length & ~(4 - 1);
for (; i < upperBound; i += 4) {
// one
float diff1 = a[i] - b[i];
acc1 = fma(diff1, diff1, acc1);

// two
float diff2 = a[i + 1] - b[i + 1];
acc2 = fma(diff2, diff2, acc2);

// three
float diff3 = a[i + 2] - b[i + 2];
acc3 = fma(diff3, diff3, acc3);

// four
float diff4 = a[i + 3] - b[i + 3];
acc4 = fma(diff4, diff4, acc4);
}
res += acc1 + acc2 + acc3 + acc4;
}
for (; i < dim; i++) {

for (; i < a.length; i++) {
float diff = a[i] - b[i];
squareSum += diff * diff;
res = fma(diff, diff, res);
}
return squareSum;
}

private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
float diff0 = v1[index + 0] - v2[index + 0];
float diff1 = v1[index + 1] - v2[index + 1];
float diff2 = v1[index + 2] - v2[index + 2];
float diff3 = v1[index + 3] - v2[index + 3];
float diff4 = v1[index + 4] - v2[index + 4];
float diff5 = v1[index + 5] - v2[index + 5];
float diff6 = v1[index + 6] - v2[index + 6];
float diff7 = v1[index + 7] - v2[index + 7];
return diff0 * diff0
+ diff1 * diff1
+ diff2 * diff2
+ diff3 * diff3
+ diff4 * diff4
+ diff5 * diff5
+ diff6 * diff6
+ diff7 * diff7;
return res;
}

@Override
Expand Down
77 changes: 70 additions & 7 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Objects;
import java.util.logging.Logger;

/** Some useful constants. */
Expand Down Expand Up @@ -91,12 +90,6 @@ private Constants() {} // can't construct
/** True iff running on a 64bit JVM */
public static final boolean JRE_IS_64BIT = is64Bit();

/** true iff we know fast FMA is supported, to deliver less error */
public static final boolean HAS_FAST_FMA =
(IS_CLIENT_VM == false)
&& Objects.equals(OS_ARCH, "amd64")
&& HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);

private static boolean is64Bit() {
final String datamodel = getSysProp("sun.arch.data.model");
if (datamodel != null) {
Expand All @@ -106,6 +99,76 @@ private static boolean is64Bit() {
}
}

/**
 * true if FMA likely means a cpu instruction and not BigDecimal logic. Requires that this is not a
 * client VM and that the Hotspot {@code UseFMA} option reports {@code true}; falls back to {@code
 * false} when no value is available for the option.
 */
private static final boolean HAS_FMA =
(IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);

/**
 * maximum supported vector size, read from the Hotspot {@code MaxVectorSize} option; {@code 0}
 * when no value is available for the option.
 */
private static final int MAX_VECTOR_SIZE =
HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);

/**
 * true for an AMD cpu with SSE4a instructions. This is derived from the Hotspot {@code UseXmmI2F}
 * option and is {@code false} when no value is available for the option.
 */
private static final boolean HAS_SSE4A =
HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);

// NOTE: the two constants below are computed from HAS_FMA, MAX_VECTOR_SIZE and HAS_SSE4A, so they
// must stay declared after them (static initializers run in textual order).

/** true iff we know VFMA has faster throughput than separate vmul/vadd */
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();

/** true iff we know FMA has faster throughput than separate mul/add */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

/**
 * Decides whether vectorized FMA should be used.
 *
 * <p>Always {@code false} when {@code HAS_FMA} is false. Otherwise the system property {@code
 * lucene.useVectorFMA} (default {@code auto}) is consulted: any value other than {@code auto} is
 * parsed with {@link Boolean#parseBoolean(String)} and returned as-is, while {@code auto} applies
 * the CPU heuristics below.
 *
 * @return true iff vectorized FMA is expected to beat separate multiply and add
 */
private static boolean hasFastVectorFMA() {
  if (HAS_FMA) {
    String value = getSysProp("lucene.useVectorFMA", "auto");
    if ("auto".equals(value)) {
      // Architecture checks use constant-first equals: this method runs from a static
      // initializer, so it must not throw should OS_ARCH ever be null.

      // newer Neoverse cores have their act together
      // the problem is just apple silicon (this is a practical heuristic)
      if ("aarch64".equals(OS_ARCH) && MAC_OS_X == false) {
        return true;
      }
      // zen cores or newer, it's a wash, turn it on as it doesn't hurt
      // starts to yield gains for vectors only at zen4+
      if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
        return true;
      }
      // intel has their act together
      if ("amd64".equals(OS_ARCH) && HAS_SSE4A == false) {
        return true;
      }
    } else {
      // explicit user override
      return Boolean.parseBoolean(value);
    }
  }
  // everyone else is slow, until proven otherwise by benchmarks
  return false;
}

/**
 * Decides whether scalar FMA should be used.
 *
 * <p>Always {@code false} when {@code HAS_FMA} is false. Otherwise the system property {@code
 * lucene.useScalarFMA} (default {@code auto}) is consulted: any value other than {@code auto} is
 * parsed with {@link Boolean#parseBoolean(String)} and returned as-is, while {@code auto} applies
 * the CPU heuristics below.
 *
 * @return true iff scalar FMA is expected to beat separate multiply and add
 */
private static boolean hasFastScalarFMA() {
  if (HAS_FMA) {
    String value = getSysProp("lucene.useScalarFMA", "auto");
    if ("auto".equals(value)) {
      // Architecture checks use constant-first equals: this method runs from a static
      // initializer, so it must not throw should OS_ARCH ever be null.

      // newer Neoverse cores have their act together
      // the problem is just apple silicon (this is a practical heuristic)
      if ("aarch64".equals(OS_ARCH) && MAC_OS_X == false) {
        return true;
      }
      // latency becomes 4 for the Zen3 (0x19h), but still a wash
      // until the Zen4 anyway, and big drop on previous zens:
      if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
        return true;
      }
      // intel has their act together
      if ("amd64".equals(OS_ARCH) && HAS_SSE4A == false) {
        return true;
      }
    } else {
      // explicit user override
      return Boolean.parseBoolean(value);
    }
  }
  // everyone else is slow, until proven otherwise by benchmarks
  return false;
}

private static String getSysProp(String property) {
try {
return doPrivileged(() -> System.getProperty(property));
Expand Down
26 changes: 25 additions & 1 deletion lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,31 @@
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

/** Utilities for computations with numeric arrays */
/**
* Utilities for computations with numeric arrays, especially algebraic operations like vector dot
* products. This class uses SIMD vectorization if the corresponding Java module is available and
* enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
* command line.
*
* <p>It will use the CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
* instructions</a> if they are known to perform faster than separate multiply+add. This requires
* at least Hotspot C2 to be enabled, which is the default for OpenJDK-based JVMs.
*
* <p>To explicitly disable or enable FMA usage, pass the following system properties:
*
* <ul>
* <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
* <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
* incubator module)
* </ul>
*
* <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
* Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
*
* <p>Vectorization and FMA are only supported for Hotspot-based JVMs; they won't work on
* OpenJ9-based JVMs unless those provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please
* also make sure that you have the {@code jdk.management} module enabled in modularized
* applications.
*/
public final class VectorUtil {

private static final VectorUtilSupport IMPL =
Expand Down
Loading

0 comments on commit 3a2ee7f

Please sign in to comment.